1 //===- OCL20ToSPIRV.cpp - Transform OCL20 to SPIR-V builtins -----*- C++ -*-===//
2 //
3 //                     The LLVM/SPIRV Translator
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 // Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
9 //
10 // Permission is hereby granted, free of charge, to any person obtaining a
11 // copy of this software and associated documentation files (the "Software"),
12 // to deal with the Software without restriction, including without limitation
13 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
14 // and/or sell copies of the Software, and to permit persons to whom the
15 // Software is furnished to do so, subject to the following conditions:
16 //
17 // Redistributions of source code must retain the above copyright notice,
18 // this list of conditions and the following disclaimers.
19 // Redistributions in binary form must reproduce the above copyright notice,
20 // this list of conditions and the following disclaimers in the documentation
21 // and/or other materials provided with the distribution.
22 // Neither the names of Advanced Micro Devices, Inc., nor the names of its
23 // contributors may be used to endorse or promote products derived from this
24 // Software without specific prior written permission.
25 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28 // CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
31 // THE SOFTWARE.
32 //
33 //===----------------------------------------------------------------------===//
34 //
35 // This file implements translation of OCL20 builtin functions.
36 //
37 //===----------------------------------------------------------------------===//
38 #define DEBUG_TYPE "cl20tospv"
39 
40 #include "SPIRVInternal.h"
41 #include "OCLUtil.h"
42 #include "OCLTypeToSPIRV.h"
43 
44 #include "llvm/ADT/StringSwitch.h"
45 #include "llvm/IR/InstVisitor.h"
46 #include "llvm/IR/Instructions.h"
47 #include "llvm/IR/Instruction.h"
48 #include "llvm/IR/IRBuilder.h"
49 #include "llvm/IR/Verifier.h"
50 #include "llvm/Pass.h"
51 #include "llvm/PassSupport.h"
52 #include "llvm/Support/Debug.h"
53 #include "llvm/Support/raw_ostream.h"
54 
55 #include <set>
56 
57 using namespace llvm;
58 using namespace SPIRV;
59 using namespace OCLUtil;
60 
61 namespace SPIRV {
62 static size_t
getOCLCpp11AtomicMaxNumOps(StringRef Name)63 getOCLCpp11AtomicMaxNumOps(StringRef Name) {
64   return StringSwitch<size_t>(Name)
65       .Cases("load", "flag_test_and_set", "flag_clear", 3)
66       .Cases("store", "exchange",  4)
67       .StartsWith("compare_exchange", 6)
68       .StartsWith("fetch", 4)
69       .Default(0);
70 }
71 
72 class OCL20ToSPIRV: public ModulePass,
73   public InstVisitor<OCL20ToSPIRV> {
74 public:
OCL20ToSPIRV()75   OCL20ToSPIRV():ModulePass(ID), M(nullptr), Ctx(nullptr), CLVer(0) {
76     initializeOCL20ToSPIRVPass(*PassRegistry::getPassRegistry());
77   }
78   virtual bool runOnModule(Module &M);
79 
getAnalysisUsage(AnalysisUsage & AU) const80   void getAnalysisUsage(AnalysisUsage &AU) const {
81     AU.addRequired<OCLTypeToSPIRV>();
82   }
83 
84   virtual void visitCallInst(CallInst &CI);
85 
86   /// Transform barrier/work_group_barrier/sub_group_barrier
87   ///     to __spirv_ControlBarrier.
88   /// barrier(flag) =>
89   ///   __spirv_ControlBarrier(workgroup, workgroup, map(flag))
90   /// work_group_barrier(scope, flag) =>
91   ///   __spirv_ControlBarrier(workgroup, map(scope), map(flag))
92   /// sub_group_barrier(scope, flag) =>
93   ///   __spirv_ControlBarrier(subgroup, map(scope), map(flag))
94   void visitCallBarrier(CallInst *CI);
95 
96   /// Erase useless convert functions.
97   /// \return true if the call instruction is erased.
98   bool eraseUselessConvert(CallInst *Call, const std::string &MangledName,
99       const std::string &DeMangledName);
100 
101   /// Transform convert_ to
102   ///   __spirv_{CastOpName}_R{TargeTyName}{_sat}{_rt[p|n|z|e]}
103   void visitCallConvert(CallInst *CI, StringRef MangledName,
104     const std::string &DemangledName);
105 
106   /// Transform async_work_group{_strided}_copy.
107   /// async_work_group_copy(dst, src, n, event)
108   ///   => async_work_group_strided_copy(dst, src, n, 1, event)
109   /// async_work_group_strided_copy(dst, src, n, stride, event)
110   ///   => __spirv_AsyncGroupCopy(ScopeWorkGroup, dst, src, n, stride, event)
111   void visitCallAsyncWorkGroupCopy(CallInst *CI,
112       const std::string &DemangledName);
113 
114   /// Transform OCL builtin function to SPIR-V builtin function.
115   void transBuiltin(CallInst *CI, OCLBuiltinTransInfo &Info);
116 
117   /// Transform OCL work item builtin functions to SPIR-V builtin variables.
118   void transWorkItemBuiltinsToVariables();
119 
120   /// Transform atomic_work_item_fence/mem_fence to __spirv_MemoryBarrier.
121   /// func(flag, order, scope) =>
122   ///   __spirv_MemoryBarrier(map(scope), map(flag)|map(order))
123   void transMemoryBarrier(CallInst *CI, AtomicWorkItemFenceLiterals);
124 
125   /// Transform all to __spirv_Op(All|Any).  Note that the types mismatch so
126   // some extra code is emitted to convert between the two.
127   void visitCallAllAny(spv::Op OC, CallInst *CI);
128 
129   /// Transform atomic_* to __spirv_Atomic*.
130   /// atomic_x(ptr_arg, args, order, scope) =>
131   ///   __spirv_AtomicY(ptr_arg, map(order), map(scope), args)
132   void transAtomicBuiltin(CallInst *CI, OCLBuiltinTransInfo &Info);
133 
134   /// Transform atomic_work_item_fence to __spirv_MemoryBarrier.
135   /// atomic_work_item_fence(flag, order, scope) =>
136   ///   __spirv_MemoryBarrier(map(scope), map(flag)|map(order))
137   void visitCallAtomicWorkItemFence(CallInst *CI);
138 
139   /// Transform atomic_compare_exchange call.
140   /// In atomic_compare_exchange, the expected value parameter is a pointer.
141   /// However in SPIR-V it is a value. The transformation adds a load
142   /// instruction, result of which is passed to atomic_compare_exchange as
143   /// argument.
144   /// The transformation adds a store instruction after the call, to update the
145   /// value in expected with the value pointed to by object. Though, it is not
146   /// necessary in case they are equal, this approach makes result code simpler.
147   /// Also ICmp instruction is added, because the call must return result of
148   /// comparison.
149   /// \returns the call instruction of atomic_compare_exchange_strong.
150   CallInst *visitCallAtomicCmpXchg(CallInst *CI,
151       const std::string &DemangledName);
152 
153   /// Transform atomic_init.
154   /// atomic_init(p, x) => store p, x
155   void visitCallAtomicInit(CallInst *CI);
156 
157   /// Transform legacy OCL 1.x atomic builtins to SPIR-V builtins for extensions
158   ///   cl_khr_int64_base_atomics
159   ///   cl_khr_int64_extended_atomics
160   /// Do nothing if the called function is not a legacy atomic builtin.
161   void visitCallAtomicLegacy(CallInst *CI, StringRef MangledName,
162     const std::string &DemangledName);
163 
164   /// Transform OCL 2.0 C++11 atomic builtins to SPIR-V builtins.
165   /// Do nothing if the called function is not a C++11 atomic builtin.
166   void visitCallAtomicCpp11(CallInst *CI, StringRef MangledName,
167     const std::string &DemangledName);
168 
169   /// Transform OCL builtin function to SPIR-V builtin function.
170   /// Assuming there is a simple name mapping without argument changes.
171   /// Should be called at last.
172   void visitCallBuiltinSimple(CallInst *CI, StringRef MangledName,
173     const std::string &DemangledName);
174 
175   /// Transform get_image_{width|height|depth|dim}.
176   /// get_image_xxx(...) =>
177   ///   dimension = __spirv_ImageQuerySizeLod_R{ReturnType}(...);
178   ///   return dimension.{x|y|z};
179   void visitCallGetImageSize(CallInst *CI, StringRef MangledName,
180     const std::string &DemangledName);
181 
182   /// Transform {work|sub}_group_x =>
183   ///   __spirv_{OpName}
184   ///
185   /// Special handling of work_group_broadcast.
186   ///   work_group_broadcast(a, x, y, z)
187   ///     =>
188   ///   __spirv_GroupBroadcast(a, vec3(x, y, z))
189 
190   void visitCallGroupBuiltin(CallInst *CI, StringRef MangledName,
191     const std::string &DemangledName);
192 
193   /// Transform mem_fence to __spirv_MemoryBarrier.
194   /// mem_fence(flag) => __spirv_MemoryBarrier(Workgroup, map(flag))
195   void visitCallMemFence(CallInst *CI);
196 
197   void visitCallNDRange(CallInst *CI, const std::string &DemangledName);
198 
199   /// Transform OCL pipe builtin function to SPIR-V pipe builtin function.
200   void visitCallPipeBuiltin(CallInst *CI, StringRef MangledName,
201     const std::string &DemangledName);
202 
203   /// Transform read_image with sampler arguments.
204   /// read_image(image, sampler, ...) =>
205   ///   sampled_image = __spirv_SampledImage(image, sampler);
206   ///   return __spirv_ImageSampleExplicitLod_R{ReturnType}(sampled_image, ...);
207   void visitCallReadImageWithSampler(CallInst *CI, StringRef MangledName,
208       const std::string &DemangledName);
209 
210   /// Transform read_image with msaa image arguments.
211   /// Sample argument must be acoded as Image Operand.
212   void visitCallReadImageMSAA(CallInst *CI, StringRef MangledName,
213                               const std::string &DemangledName);
214 
215   /// Transform {read|write}_image without sampler arguments.
216   void visitCallReadWriteImage(CallInst *CI, StringRef MangledName,
217       const std::string &DemangledName);
218 
219   /// Transform to_{global|local|private}.
220   ///
221   /// T* a = ...;
222   /// addr T* b = to_addr(a);
223   ///   =>
224   /// i8* x = cast<i8*>(a);
225   /// addr i8* y = __spirv_GenericCastToPtr_ToAddr(x);
226   /// addr T* b = cast<addr T*>(y);
227   void visitCallToAddr(CallInst *CI, StringRef MangledName,
228       const std::string &DemangledName);
229 
230   /// Transform return type of relatinal built-in functions like isnan, isfinite
231   /// to boolean values.
232   void visitCallRelational(CallInst *CI, const std::string &DemangledName);
233 
234   /// Transform vector load/store functions to SPIR-V extended builtin
235   ///   functions
236   /// {vload|vstore{a}}{_half}{n}{_rte|_rtz|_rtp|_rtn} =>
237   ///   __spirv_ocl_{ExtendedInstructionOpCodeName}__R{ReturnType}
238   void visitCallVecLoadStore(CallInst *CI, StringRef MangledName,
239       const std::string &DemangledName);
240 
241   /// Transforms get_mem_fence built-in to SPIR-V function and aligns result values with SPIR 1.2.
242   /// get_mem_fence(ptr) => __spirv_GenericPtrMemSemantics
243   /// GenericPtrMemSemantics valid values are 0x100, 0x200 and 0x300, where is
244   /// SPIR 1.2 defines them as 0x1, 0x2 and 0x3, so this function adjusts
245   /// GenericPtrMemSemantics results to SPIR 1.2 values.
246   void visitCallGetFence(CallInst *CI, StringRef MangledName, const std::string& DemangledName);
247 
248   /// Transforms OpDot instructions with a scalar type to a fmul instruction
249   void visitCallDot(CallInst *CI);
250 
251   /// Fixes for built-in functions with vector+scalar arguments that are
252   /// translated to the SPIR-V instructions where all arguments must have the
253   /// same type.
254   void visitCallScalToVec(CallInst *CI, StringRef MangledName,
255                           const std::string &DemangledName);
256 
257   /// Transform get_image_channel_{order|data_type} built-in functions to
258   ///   __spirv_ocl_{ImageQueryOrder|ImageQueryFormat}
259   void visitCallGetImageChannel(CallInst *CI, StringRef MangledName,
260                                 const std::string &DemangledName,
261                                 unsigned int Offset);
262 
visitDbgInfoIntrinsic(DbgInfoIntrinsic & I)263   void visitDbgInfoIntrinsic(DbgInfoIntrinsic &I){
264     I.dropAllReferences();
265     I.eraseFromParent();
266   }
267   static char ID;
268 private:
269   Module *M;
270   LLVMContext *Ctx;
271   unsigned CLVer;                   /// OpenCL version as major*10+minor
272   std::set<Value *> ValuesToDelete;
273 
addInt32(int I)274   ConstantInt *addInt32(int I) {
275     return getInt32(M, I);
276   }
addSizet(uint64_t I)277   ConstantInt *addSizet(uint64_t I) {
278     return getSizet(M, I);
279   }
280 
281   /// Get vector width from OpenCL vload* function name.
getVecLoadWidth(const std::string & DemangledName)282   SPIRVWord getVecLoadWidth(const std::string& DemangledName) {
283     SPIRVWord Width = 0;
284     if (DemangledName == "vloada_half")
285       Width = 1;
286     else {
287       unsigned Loc = 5;
288       if (DemangledName.find("vload_half") == 0)
289         Loc = 10;
290       else if (DemangledName.find("vloada_half") == 0)
291         Loc = 11;
292 
293       std::stringstream SS(DemangledName.substr(Loc));
294       SS >> Width;
295     }
296     return Width;
297   }
298 
299   /// Transform OpenCL vload/vstore function name.
transVecLoadStoreName(std::string & DemangledName,const std::string & Stem,bool AlwaysN)300   void transVecLoadStoreName(std::string& DemangledName,
301       const std::string &Stem, bool AlwaysN) {
302     auto HalfStem = Stem + "_half";
303     auto HalfStemR = HalfStem + "_r";
304     if (!AlwaysN && DemangledName == HalfStem)
305       return;
306     if (!AlwaysN && DemangledName.find(HalfStemR) == 0) {
307       DemangledName = HalfStemR;
308       return;
309     }
310     if (DemangledName.find(HalfStem) == 0) {
311       auto OldName = DemangledName;
312       DemangledName = HalfStem + "n";
313       if (OldName.find("_r") != std::string::npos)
314         DemangledName += "_r";
315       return;
316     }
317     if (DemangledName.find(Stem) == 0) {
318       DemangledName = Stem + "n";
319       return;
320     }
321   }
322 
323 };
324 
// Pass identifier, handed to the ModulePass base in the OCL20ToSPIRV
// constructor. NOTE(review): LLVM legacy passes are presumably identified by
// the address of this object rather than by its value — confirm.
char OCL20ToSPIRV::ID = 0;
326 
327 bool
runOnModule(Module & Module)328 OCL20ToSPIRV::runOnModule(Module& Module) {
329   M = &Module;
330   Ctx = &M->getContext();
331   auto Src = getSPIRVSource(&Module);
332   if (std::get<0>(Src) != spv::SourceLanguageOpenCL_C)
333     return false;
334 
335   CLVer = std::get<1>(Src);
336   if (CLVer > kOCLVer::CL20)
337     return false;
338 
339   DEBUG(dbgs() << "Enter OCL20ToSPIRV:\n");
340 
341   transWorkItemBuiltinsToVariables();
342 
343   visit(*M);
344 
345   for (auto &I:ValuesToDelete)
346     if (auto Inst = dyn_cast<Instruction>(I))
347       Inst->eraseFromParent();
348   for (auto &I:ValuesToDelete)
349     if (auto GV = dyn_cast<GlobalValue>(I))
350       GV->eraseFromParent();
351 
352   DEBUG(dbgs() << "After OCL20ToSPIRV:\n" << *M);
353 
354   std::string Err;
355   raw_string_ostream ErrorOS(Err);
356   if (verifyModule(*M, &ErrorOS)){
357     DEBUG(errs() << "Fails to verify module: " << ErrorOS.str());
358   }
359   return true;
360 }
361 
// The order of handling OCL builtin functions is important.
// Workgroup functions need to be handled before pipe functions since
// there are functions that fall into both categories.
/// Entry point for every call instruction in the module.
///
/// For a direct call to an OpenCL builtin, the demangled name selects exactly
/// one handler. The checks below are tried in the order written and the first
/// match returns, so their order matters. Anything not matched earlier falls
/// through to visitCallBuiltinSimple().
void
OCL20ToSPIRV::visitCallInst(CallInst& CI) {
  DEBUG(dbgs() << "[visistCallInst] " << CI << '\n');
  auto F = CI.getCalledFunction();
  // Indirect calls have no known callee and are left alone.
  if (!F)
    return;

  auto MangledName = F->getName();
  std::string DemangledName;
  // Non-builtins are skipped; oclIsBuiltin also produces the demangled name.
  if (!oclIsBuiltin(MangledName, &DemangledName))
    return;

  DEBUG(dbgs() << "DemangledName: " << DemangledName << '\n');
  if (DemangledName.find(kOCLBuiltinName::NDRangePrefix) == 0) {
    visitCallNDRange(&CI, DemangledName);
    return;
  }
  if (DemangledName == kOCLBuiltinName::All) {
      visitCallAllAny(OpAll, &CI);
      return;
  }
  if (DemangledName == kOCLBuiltinName::Any) {
      visitCallAllAny(OpAny, &CI);
      return;
  }
  if (DemangledName.find(kOCLBuiltinName::AsyncWorkGroupCopy) == 0 ||
      DemangledName.find(kOCLBuiltinName::AsyncWorkGroupStridedCopy) == 0) {
    visitCallAsyncWorkGroupCopy(&CI, DemangledName);
    return;
  }
  // Atomics. atomic_init and atomic_work_item_fence have dedicated lowerings.
  // The compare-exchange family is first rewritten by visitCallAtomicCmpXchg,
  // which returns the replacement call. The legacy and C++11 handlers are then
  // both tried on the (possibly new) call; each one returns early for names
  // it does not handle.
  if (DemangledName.find(kOCLBuiltinName::AtomicPrefix) == 0 ||
      DemangledName.find(kOCLBuiltinName::AtomPrefix) == 0) {
    auto PCI = &CI;
    if (DemangledName == kOCLBuiltinName::AtomicInit) {
      visitCallAtomicInit(PCI);
      return;
    }
    if (DemangledName == kOCLBuiltinName::AtomicWorkItemFence) {
      visitCallAtomicWorkItemFence(PCI);
      return;
    }
    if (DemangledName == kOCLBuiltinName::AtomicCmpXchgWeak ||
        DemangledName == kOCLBuiltinName::AtomicCmpXchgStrong ||
        DemangledName == kOCLBuiltinName::AtomicCmpXchgWeakExplicit ||
        DemangledName == kOCLBuiltinName::AtomicCmpXchgStrongExplicit) {
      assert(CLVer == kOCLVer::CL20 && "Wrong version of OpenCL");
      PCI = visitCallAtomicCmpXchg(PCI, DemangledName);
    }
    visitCallAtomicLegacy(PCI, MangledName, DemangledName);
    visitCallAtomicCpp11(PCI, MangledName, DemangledName);
    return;
  }
  if (DemangledName.find(kOCLBuiltinName::ConvertPrefix) == 0) {
    visitCallConvert(&CI, MangledName, DemangledName);
    return;
  }
  if (DemangledName == kOCLBuiltinName::GetImageWidth ||
      DemangledName == kOCLBuiltinName::GetImageHeight ||
      DemangledName == kOCLBuiltinName::GetImageDepth ||
      DemangledName == kOCLBuiltinName::GetImageDim   ||
      DemangledName == kOCLBuiltinName::GetImageArraySize) {
    visitCallGetImageSize(&CI, MangledName, DemangledName);
    return;
  }
  // Work-group / sub-group builtins (except the barriers, handled further
  // down) and wait_group_events. This must precede the pipe check because
  // some group builtins also contain the pipe substring.
  if ((DemangledName.find(kOCLBuiltinName::WorkGroupPrefix) == 0 &&
      DemangledName != kOCLBuiltinName::WorkGroupBarrier) ||
      DemangledName == kOCLBuiltinName::WaitGroupEvent ||
      (DemangledName.find(kOCLBuiltinName::SubGroupPrefix) == 0 &&
       DemangledName != kOCLBuiltinName::SubGroupBarrier)) {
    visitCallGroupBuiltin(&CI, MangledName, DemangledName);
    return;
  }
  // Note: a substring match, not a prefix match.
  if (DemangledName.find(kOCLBuiltinName::Pipe) != std::string::npos) {
    visitCallPipeBuiltin(&CI, MangledName, DemangledName);
    return;
  }
  if (DemangledName == kOCLBuiltinName::MemFence) {
    visitCallMemFence(&CI);
    return;
  }
  // read_image with a sampler, or on an MSAA image, has a dedicated handler.
  // Any other read_image falls through to the generic read/write handler.
  if (DemangledName.find(kOCLBuiltinName::ReadImage) == 0) {
    if (MangledName.find(kMangledName::Sampler) != StringRef::npos) {
      visitCallReadImageWithSampler(&CI, MangledName, DemangledName);
      return;
    }
    if (MangledName.find("msaa") != StringRef::npos) {
      visitCallReadImageMSAA(&CI, MangledName, DemangledName);
      return;
    }
  }
  if (DemangledName.find(kOCLBuiltinName::ReadImage) == 0 ||
      DemangledName.find(kOCLBuiltinName::WriteImage) == 0) {
    visitCallReadWriteImage(&CI, MangledName, DemangledName);
    return;
  }
  if (DemangledName == kOCLBuiltinName::ToGlobal ||
      DemangledName == kOCLBuiltinName::ToLocal ||
      DemangledName == kOCLBuiltinName::ToPrivate) {
    visitCallToAddr(&CI, MangledName, DemangledName);
    return;
  }
  if (DemangledName.find(kOCLBuiltinName::VLoadPrefix) == 0 ||
      DemangledName.find(kOCLBuiltinName::VStorePrefix) == 0) {
    visitCallVecLoadStore(&CI, MangledName, DemangledName);
    return;
  }
  if (DemangledName == kOCLBuiltinName::IsFinite ||
      DemangledName == kOCLBuiltinName::IsInf ||
      DemangledName == kOCLBuiltinName::IsNan ||
      DemangledName == kOCLBuiltinName::IsNormal ||
      DemangledName == kOCLBuiltinName::Signbit) {
    visitCallRelational(&CI, DemangledName);
    return;
  }
  if (DemangledName == kOCLBuiltinName::WorkGroupBarrier ||
      DemangledName == kOCLBuiltinName::Barrier) {
    visitCallBarrier(&CI);
    return;
  }
  if (DemangledName == kOCLBuiltinName::GetFence) {
    visitCallGetFence(&CI, MangledName, DemangledName);
    return;
  }
  // Only scalar dot() is rewritten here. Vector dot falls through to
  // visitCallBuiltinSimple.
  if (DemangledName == kOCLBuiltinName::Dot &&
      !(CI.getOperand(0)->getType()->isVectorTy())) {
    visitCallDot(&CI);
    return;
  }
  if (DemangledName == kOCLBuiltinName::FMin ||
      DemangledName == kOCLBuiltinName::FMax ||
      DemangledName == kOCLBuiltinName::Min ||
      DemangledName == kOCLBuiltinName::Max ||
      DemangledName == kOCLBuiltinName::Step ||
      DemangledName == kOCLBuiltinName::SmoothStep ||
      DemangledName == kOCLBuiltinName::Clamp ||
      DemangledName == kOCLBuiltinName::Mix) {
    visitCallScalToVec(&CI, MangledName, DemangledName);
    return;
  }
  if (DemangledName == kOCLBuiltinName::GetImageChannelDataType) {
    visitCallGetImageChannel(&CI, MangledName, DemangledName,
                             OCLImageChannelDataTypeOffset);
    return;
  }
  if (DemangledName == kOCLBuiltinName::GetImageChannelOrder) {
    visitCallGetImageChannel(&CI, MangledName, DemangledName,
                             OCLImageChannelOrderOffset);
    return;
  }
  // Fallback: plain name mapping with unchanged arguments.
  visitCallBuiltinSimple(&CI, MangledName, DemangledName);
}
516 
517 void
visitCallNDRange(CallInst * CI,const std::string & DemangledName)518 OCL20ToSPIRV::visitCallNDRange(CallInst *CI,
519     const std::string &DemangledName) {
520   assert(DemangledName.find(kOCLBuiltinName::NDRangePrefix) == 0);
521   std::string lenStr = DemangledName.substr(8, 1);
522   auto Len = atoi(lenStr.c_str());
523   assert (Len >= 1 && Len <= 3);
524   // SPIR-V ndrange structure requires 3 members in the following order:
525   //   global work offset
526   //   global work size
527   //   local work size
528   // The arguments need to add missing members.
529   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
530   mutateCallInstSPIRV(M, CI, [=](CallInst *, std::vector<Value *> &Args){
531     for (size_t I = 1, E = Args.size(); I != E; ++I)
532       Args[I] = getScalarOrArray(Args[I], Len, CI);
533     switch (Args.size()) {
534     case 2: {
535       // Has global work size.
536       auto T = Args[1]->getType();
537       auto C = getScalarOrArrayConstantInt(CI, T, Len, 0);
538       Args.push_back(C);
539       Args.push_back(C);
540     }
541       break;
542     case 3: {
543       // Has global and local work size.
544       auto T = Args[1]->getType();
545       Args.push_back(getScalarOrArrayConstantInt(CI, T, Len, 0));
546     }
547       break;
548     case 4: {
549       // Move offset arg to the end
550       auto OffsetPos = Args.begin() + 1;
551       Value* OffsetVal = *OffsetPos;
552       Args.erase(OffsetPos);
553       Args.push_back(OffsetVal);
554     }
555       break;
556     default:
557       assert(0 && "Invalid number of arguments");
558     }
559     // Translate ndrange_ND into differently named SPIR-V decorated functions because
560     // they have array arugments of different dimension which mangled the same way.
561     return getSPIRVFuncName(OpBuildNDRange, "_" + lenStr + "D");
562   }, &Attrs);
563 }
564 
565 void
visitCallAsyncWorkGroupCopy(CallInst * CI,const std::string & DemangledName)566 OCL20ToSPIRV::visitCallAsyncWorkGroupCopy(CallInst* CI,
567     const std::string &DemangledName) {
568   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
569   mutateCallInstSPIRV(M, CI, [=](CallInst *, std::vector<Value *> &Args){
570     if (DemangledName == OCLUtil::kOCLBuiltinName::AsyncWorkGroupCopy) {
571       Args.insert(Args.begin()+3, addSizet(1));
572     }
573     Args.insert(Args.begin(), addInt32(ScopeWorkgroup));
574     return getSPIRVFuncName(OpGroupAsyncCopy);
575   }, &Attrs);
576 }
577 
578 CallInst *
visitCallAtomicCmpXchg(CallInst * CI,const std::string & DemangledName)579 OCL20ToSPIRV::visitCallAtomicCmpXchg(CallInst* CI,
580     const std::string& DemangledName) {
581   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
582   Value *Expected = nullptr;
583   CallInst *NewCI = nullptr;
584   mutateCallInstOCL(M, CI, [&](CallInst * CI, std::vector<Value *> &Args,
585       Type *&RetTy){
586     Expected = Args[1]; // temporary save second argument.
587     Args[1] = new LoadInst(Args[1], "exp", false, CI);
588     RetTy = Args[2]->getType();
589     assert(Args[0]->getType()->getPointerElementType()->isIntegerTy() &&
590       Args[1]->getType()->isIntegerTy() && Args[2]->getType()->isIntegerTy() &&
591       "In SPIR-V 1.0 arguments of OpAtomicCompareExchange must be "
592       "an integer type scalars");
593     return kOCLBuiltinName::AtomicCmpXchgStrong;
594   },
595   [&](CallInst *NCI)->Instruction * {
596     NewCI = NCI;
597     Instruction* Store = new StoreInst(NCI, Expected, NCI->getNextNode());
598     return new ICmpInst(Store->getNextNode(), CmpInst::ICMP_EQ, NCI,
599                         NCI->getArgOperand(1));
600   },
601   &Attrs);
602   return NewCI;
603 }
604 
605 void
visitCallAtomicInit(CallInst * CI)606 OCL20ToSPIRV::visitCallAtomicInit(CallInst* CI) {
607   auto ST = new StoreInst(CI->getArgOperand(1), CI->getArgOperand(0), CI);
608   ST->takeName(CI);
609   CI->dropAllReferences();
610   CI->eraseFromParent();
611 }
612 
613 void
visitCallAllAny(spv::Op OC,CallInst * CI)614 OCL20ToSPIRV::visitCallAllAny(spv::Op OC, CallInst* CI) {
615   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
616 
617   auto Args = getArguments(CI);
618   assert(Args.size() == 1);
619 
620   auto *ArgTy = Args[0]->getType();
621   auto Zero = Constant::getNullValue(Args[0]->getType());
622 
623   auto *Cmp = CmpInst::Create(CmpInst::ICmp, CmpInst::ICMP_SLT, Args[0], Zero,
624                                "cast", CI);
625 
626   if (!isa<VectorType>(ArgTy)) {
627     auto *Cast = CastInst::CreateZExtOrBitCast(Cmp, Type::getInt32Ty(*Ctx),
628                                                 "", Cmp->getNextNode());
629     CI->replaceAllUsesWith(Cast);
630     CI->eraseFromParent();
631   } else {
632     mutateCallInstSPIRV(
633         M, CI,
634         [&](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
635           Args[0] = Cmp;
636           Ret = Type::getInt1Ty(*Ctx);
637 
638           return getSPIRVFuncName(OC);
639         },
640         [&](CallInst *CI) -> Instruction * {
641           return CastInst::CreateZExtOrBitCast(CI, Type::getInt32Ty(*Ctx), "",
642                                                CI->getNextNode());
643         },
644         &Attrs);
645   }
646 }
647 
648 void
visitCallAtomicWorkItemFence(CallInst * CI)649 OCL20ToSPIRV::visitCallAtomicWorkItemFence(CallInst* CI) {
650   transMemoryBarrier(CI, getAtomicWorkItemFenceLiterals(CI));
651 }
652 
653 void
visitCallMemFence(CallInst * CI)654 OCL20ToSPIRV::visitCallMemFence(CallInst* CI) {
655   transMemoryBarrier(CI, std::make_tuple(
656       cast<ConstantInt>(CI->getArgOperand(0))->getZExtValue(),
657       OCLMO_relaxed,
658       OCLMS_work_group));
659 }
660 
transMemoryBarrier(CallInst * CI,AtomicWorkItemFenceLiterals Lit)661 void OCL20ToSPIRV::transMemoryBarrier(CallInst* CI,
662     AtomicWorkItemFenceLiterals Lit) {
663   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
664   mutateCallInstSPIRV(M, CI, [=](CallInst *, std::vector<Value *> &Args){
665     Args.resize(2);
666     Args[0] = addInt32(map<Scope>(std::get<2>(Lit)));
667     Args[1] = addInt32(mapOCLMemSemanticToSPIRV(std::get<0>(Lit),
668         std::get<1>(Lit)));
669     return getSPIRVFuncName(OpMemoryBarrier);
670   }, &Attrs);
671 }
672 
673 void
visitCallAtomicLegacy(CallInst * CI,StringRef MangledName,const std::string & DemangledName)674 OCL20ToSPIRV::visitCallAtomicLegacy(CallInst* CI,
675     StringRef MangledName, const std::string& DemangledName) {
676   StringRef Stem = DemangledName;
677   if (Stem.startswith("atom_"))
678     Stem = Stem.drop_front(strlen("atom_"));
679   else if (Stem.startswith("atomic_"))
680     Stem = Stem.drop_front(strlen("atomic_"));
681   else
682     return;
683 
684   std::string Sign;
685   std::string Postfix;
686   std::string Prefix;
687   if (Stem == "add" ||
688       Stem == "sub" ||
689       Stem == "and" ||
690       Stem == "or" ||
691       Stem == "xor" ||
692       Stem == "min" ||
693       Stem == "max") {
694     if ((Stem == "min" || Stem == "max") &&
695          isMangledTypeUnsigned(MangledName.back()))
696       Sign = 'u';
697     Prefix = "fetch_";
698     Postfix = "_explicit";
699   } else if (Stem == "xchg") {
700     Stem = "exchange";
701     Postfix = "_explicit";
702   }
703   else if (Stem == "cmpxchg") {
704     Stem = "compare_exchange_strong";
705     Postfix = "_explicit";
706   }
707   else if (Stem == "inc" ||
708            Stem == "dec") {
709     // do nothing
710   } else
711     return;
712 
713   OCLBuiltinTransInfo Info;
714   Info.UniqName = "atomic_" + Prefix + Sign + Stem.str() + Postfix;
715   std::vector<int> PostOps;
716   PostOps.push_back(OCLLegacyAtomicMemOrder);
717   if (Stem.startswith("compare_exchange"))
718     PostOps.push_back(OCLLegacyAtomicMemOrder);
719   PostOps.push_back(OCLLegacyAtomicMemScope);
720 
721   Info.PostProc = [=](std::vector<Value *> &Ops){
722     for (auto &I:PostOps){
723       Ops.push_back(addInt32(I));
724     }
725   };
726   transAtomicBuiltin(CI, Info);
727 }
728 
729 void
visitCallAtomicCpp11(CallInst * CI,StringRef MangledName,const std::string & DemangledName)730 OCL20ToSPIRV::visitCallAtomicCpp11(CallInst* CI,
731     StringRef MangledName, const std::string& DemangledName) {
732   StringRef Stem = DemangledName;
733   if (Stem.startswith("atomic_"))
734     Stem = Stem.drop_front(strlen("atomic_"));
735   else
736     return;
737 
738   std::string NewStem = Stem;
739   std::vector<int> PostOps;
740   if (Stem.startswith("store") ||
741       Stem.startswith("load") ||
742       Stem.startswith("exchange") ||
743       Stem.startswith("compare_exchange") ||
744       Stem.startswith("fetch") ||
745       Stem.startswith("flag")) {
746     if ((Stem.startswith("fetch_min") ||
747         Stem.startswith("fetch_max")) &&
748         containsUnsignedAtomicType(MangledName))
749       NewStem.insert(NewStem.begin() + strlen("fetch_"), 'u');
750 
751     if (!Stem.endswith("_explicit")) {
752       NewStem = NewStem + "_explicit";
753       PostOps.push_back(OCLMO_seq_cst);
754       if (Stem.startswith("compare_exchange"))
755         PostOps.push_back(OCLMO_seq_cst);
756       PostOps.push_back(OCLMS_device);
757     } else {
758       auto MaxOps = getOCLCpp11AtomicMaxNumOps(
759           Stem.drop_back(strlen("_explicit")));
760       if (CI->getNumArgOperands() < MaxOps)
761         PostOps.push_back(OCLMS_device);
762     }
763   } else if (Stem == "work_item_fence") {
764     // do nothing
765   } else
766     return;
767 
768   OCLBuiltinTransInfo Info;
769   Info.UniqName = std::string("atomic_") + NewStem;
770   Info.PostProc = [=](std::vector<Value *> &Ops){
771     for (auto &I:PostOps){
772       Ops.push_back(addInt32(I));
773     }
774   };
775 
776   transAtomicBuiltin(CI, Info);
777 }
778 
/// Lowers an OpenCL atomic builtin, named by \p Info.UniqName, to the SPIR-V
/// builtin that OCLSPIRVBuiltinMap maps it to.
///
/// Steps:
///   1. Info.PostProc is applied to the argument list first (the callers in
///      this file use it to append implicit order/scope operands).
///   2. The trailing scope operand and the memory-order operands are mapped
///      to SPIR-V values. They must be ConstantInts, since cast<> is used.
///   3. The operands are permuted from the OpenCL order into the SPIR-V order.
///
/// Worked examples of step 3:
///   fetch_add_explicit:      (obj, val, order, scope)
///                         -> (obj, scope, order, val)
///   compare_exchange_strong_explicit:
///        (obj, expected, desired, success, failure, scope)
///     -> (obj, scope, success, failure, desired, expected)
void
OCL20ToSPIRV::transAtomicBuiltin(CallInst* CI,
    OCLBuiltinTransInfo& Info) {
  AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
  mutateCallInstSPIRV(M, CI, [=](CallInst * CI, std::vector<Value *> &Args){
    Info.PostProc(Args);
    // Order of args in OCL20:
    // object, 0-2 other args, 1-2 order, scope
    const size_t NumOrder = getAtomicBuiltinNumMemoryOrderArgs(Info.UniqName);
    const size_t ArgsCount = Args.size();
    const size_t ScopeIdx = ArgsCount - 1;
    const size_t OrderIdx = ScopeIdx - NumOrder;
    // OpenCL scope enum -> SPIR-V Scope.
    Args[ScopeIdx] = mapUInt(M, cast<ConstantInt>(Args[ScopeIdx]),
        [](unsigned I){
      return map<Scope>(static_cast<OCLScopeKind>(I));
    });
    // OpenCL memory order -> SPIR-V memory semantics (no fence flags).
    for (size_t I = 0; I < NumOrder; ++I)
      Args[OrderIdx + I] = mapUInt(M, cast<ConstantInt>(Args[OrderIdx + I]),
          [](unsigned Ord) {
      return mapOCLMemSemanticToSPIRV(0, static_cast<OCLMemOrderKind>(Ord));
    });
    // Order of args in SPIR-V:
    // object, scope, 1-2 order, 0-2 other args
    // The swap brings scope into slot 1 and sends the first "other" arg (if
    // any) to the end.
    std::swap(Args[1], Args[ScopeIdx]);
    // OrderIdx > 2 means there were two "other" args. The rotate then moves
    // the order operands in front of the remaining other arg.
    if(OrderIdx > 2) {
      // For atomic_compare_exchange the swap above puts Comparator/Expected
      // argument just where it should be, so don't move the last argument then.
      int offset = Info.UniqName.find("atomic_compare_exchange") == 0 ? 1 : 0;
      std::rotate(Args.begin() + 2, Args.begin() + OrderIdx,
                  Args.end() - offset);
    }
    return getSPIRVFuncName(OCLSPIRVBuiltinMap::map(Info.UniqName));
  }, &Attrs);
}
813 
814 void
visitCallBarrier(CallInst * CI)815 OCL20ToSPIRV::visitCallBarrier(CallInst* CI) {
816   auto Lit = getBarrierLiterals(CI);
817   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
818   mutateCallInstSPIRV(M, CI, [=](CallInst *, std::vector<Value *> &Args){
819     Args.resize(3);
820     Args[0] = addInt32(map<Scope>(std::get<2>(Lit)));
821     Args[1] = addInt32(map<Scope>(std::get<1>(Lit)));
822     Args[2] = addInt32(mapOCLMemFenceFlagToSPIRV(std::get<0>(Lit)));
823     return getSPIRVFuncName(OpControlBarrier);
824   }, &Attrs);
825 }
826 
visitCallConvert(CallInst * CI,StringRef MangledName,const std::string & DemangledName)827 void OCL20ToSPIRV::visitCallConvert(CallInst* CI,
828     StringRef MangledName, const std::string& DemangledName) {
829   if (eraseUselessConvert(CI, MangledName, DemangledName))
830     return;
831   Op OC = OpNop;
832   auto TargetTy = CI->getType();
833   auto SrcTy = CI->getArgOperand(0)->getType();
834   if (isa<VectorType>(TargetTy))
835     TargetTy = TargetTy->getVectorElementType();
836   if (isa<VectorType>(SrcTy))
837     SrcTy = SrcTy->getVectorElementType();
838   auto IsTargetInt = isa<IntegerType>(TargetTy);
839 
840   std::string TargetTyName = DemangledName.substr(
841       strlen(kOCLBuiltinName::ConvertPrefix));
842   auto FirstUnderscoreLoc = TargetTyName.find('_');
843   if (FirstUnderscoreLoc != std::string::npos)
844     TargetTyName = TargetTyName.substr(0, FirstUnderscoreLoc);
845   TargetTyName = std::string("_R") + TargetTyName;
846 
847   std::string Sat = DemangledName.find("_sat") != std::string::npos ?
848       "_sat" : "";
849   auto TargetSigned = DemangledName[8] != 'u';
850   if (isa<IntegerType>(SrcTy)) {
851     bool Signed = isLastFuncParamSigned(MangledName);
852     if (IsTargetInt) {
853       if (!Sat.empty() && TargetSigned != Signed) {
854         OC = Signed ? OpSatConvertSToU : OpSatConvertUToS;
855         Sat = "";
856       } else
857         OC = Signed ? OpSConvert : OpUConvert;
858     } else
859       OC = Signed ? OpConvertSToF : OpConvertUToF;
860   } else {
861     if (IsTargetInt) {
862       OC = TargetSigned ? OpConvertFToS : OpConvertFToU;
863     } else
864       OC = OpFConvert;
865   }
866   auto Loc = DemangledName.find("_rt");
867   std::string Rounding;
868   if (Loc != std::string::npos &&
869       !(isa<IntegerType>(SrcTy) && IsTargetInt)) {
870     Rounding = DemangledName.substr(Loc, 4);
871   }
872   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
873   mutateCallInstSPIRV(M, CI, [=](CallInst *, std::vector<Value *> &Args){
874     return getSPIRVFuncName(OC, TargetTyName + Sat + Rounding);
875   }, &Attrs);
876 }
877 
visitCallGroupBuiltin(CallInst * CI,StringRef MangledName,const std::string & OrigDemangledName)878 void OCL20ToSPIRV::visitCallGroupBuiltin(CallInst* CI,
879     StringRef MangledName, const std::string& OrigDemangledName) {
880   auto F = CI->getCalledFunction();
881   std::vector<int> PreOps;
882   std::string DemangledName = OrigDemangledName;
883 
884   if (DemangledName == kOCLBuiltinName::WorkGroupBarrier)
885     return;
886   if (DemangledName == kOCLBuiltinName::WaitGroupEvent) {
887     PreOps.push_back(ScopeWorkgroup);
888   } else if (DemangledName.find(kOCLBuiltinName::WorkGroupPrefix) == 0) {
889     DemangledName.erase(0, strlen(kOCLBuiltinName::WorkPrefix));
890     PreOps.push_back(ScopeWorkgroup);
891   } else if (DemangledName.find(kOCLBuiltinName::SubGroupPrefix) == 0) {
892     DemangledName.erase(0, strlen(kOCLBuiltinName::SubPrefix));
893     PreOps.push_back(ScopeSubgroup);
894   } else
895     return;
896 
897   if (DemangledName != kOCLBuiltinName::WaitGroupEvent) {
898     StringRef GroupOp = DemangledName;
899     GroupOp = GroupOp.drop_front(strlen(kSPIRVName::GroupPrefix));
900     SPIRSPIRVGroupOperationMap::foreach_conditional([&](const std::string &S,
901         SPIRVGroupOperationKind G){
902       if (!GroupOp.startswith(S))
903         return true; // continue
904       PreOps.push_back(G);
905       StringRef Op = GroupOp.drop_front(S.size() + 1);
906       assert(!Op.empty() && "Invalid OpenCL group builtin function");
907       char OpTyC = 0;
908       auto NeedSign = Op == "max" || Op == "min";
909       auto OpTy = F->getReturnType();
910       if (OpTy->isFloatingPointTy())
911         OpTyC = 'f';
912       else if (OpTy->isIntegerTy()) {
913         if (!NeedSign)
914           OpTyC = 'i';
915         else {
916           if (isLastFuncParamSigned(F->getName()))
917             OpTyC = 's';
918           else
919             OpTyC = 'u';
920         }
921       } else
922         llvm_unreachable("Invalid OpenCL group builtin argument type");
923 
924       DemangledName = std::string(kSPIRVName::GroupPrefix) + OpTyC + Op.str();
925       return false; // break out of loop
926     });
927   }
928 
929   bool IsGroupAllAny = (DemangledName.find("_all") != std::string::npos ||
930                         DemangledName.find("_any") != std::string::npos);
931 
932   auto Consts = getInt32(M, PreOps);
933   OCLBuiltinTransInfo Info;
934   if (IsGroupAllAny)
935     Info.RetTy = Type::getInt1Ty(*Ctx);
936   Info.UniqName = DemangledName;
937   Info.PostProc = [=](std::vector<Value *> &Ops) {
938     if (IsGroupAllAny) {
939       IRBuilder<> IRB(CI);
940       Ops[0] =
941           IRB.CreateICmpNE(Ops[0], ConstantInt::get(Type::getInt32Ty(*Ctx), 0));
942     }
943     size_t E = Ops.size();
944     if (DemangledName == "group_broadcast" && E > 2) {
945       assert(E == 3 || E == 4);
946       makeVector(CI, Ops, std::make_pair(Ops.begin() + 1, Ops.end()));
947     }
948     Ops.insert(Ops.begin(), Consts.begin(), Consts.end());
949   };
950   transBuiltin(CI, Info);
951 }
952 
953 void
transBuiltin(CallInst * CI,OCLBuiltinTransInfo & Info)954 OCL20ToSPIRV::transBuiltin(CallInst* CI,
955     OCLBuiltinTransInfo& Info) {
956   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
957   Op OC = OpNop;
958   unsigned ExtOp = ~0U;
959   if (StringRef(Info.UniqName).startswith(kSPIRVName::Prefix))
960       return;
961   if (OCLSPIRVBuiltinMap::find(Info.UniqName, &OC))
962     Info.UniqName = getSPIRVFuncName(OC);
963   else if ((ExtOp = getExtOp(Info.MangledName, Info.UniqName)) != ~0U)
964     Info.UniqName = getSPIRVExtFuncName(SPIRVEIS_OpenCL, ExtOp);
965   else
966     return;
967   if (!Info.RetTy)
968     mutateCallInstSPIRV(M, CI,
969                         [=](CallInst *, std::vector<Value *> &Args) {
970                           Info.PostProc(Args);
971                           return Info.UniqName + Info.Postfix;
972                         },
973                         &Attrs);
974   else
975     mutateCallInstSPIRV(
976         M, CI,
977         [=](CallInst *, std::vector<Value *> &Args, Type *&RetTy) {
978           Info.PostProc(Args);
979           RetTy = Info.RetTy;
980           return Info.UniqName + Info.Postfix;
981         },
982         [=](CallInst *NewCI) -> Instruction * {
983           if (NewCI->getType()->isIntegerTy() && CI->getType()->isIntegerTy())
984             return CastInst::CreateIntegerCast(NewCI, CI->getType(),
985                                                Info.isRetSigned, "", CI);
986           else
987             return CastInst::CreatePointerBitCastOrAddrSpaceCast(
988                 NewCI, CI->getType(), "", CI);
989         },
990         &Attrs);
991 }
992 
993 void
visitCallPipeBuiltin(CallInst * CI,StringRef MangledName,const std::string & DemangledName)994 OCL20ToSPIRV::visitCallPipeBuiltin(CallInst* CI,
995     StringRef MangledName, const std::string& DemangledName) {
996   std::string NewName = DemangledName;
997   // Transform OpenCL read_pipe/write_pipe builtin function names
998   // with reserve_id argument to reserved_read_pipe/reserved_write_pipe.
999   if ((DemangledName.find(kOCLBuiltinName::ReadPipe) == 0 ||
1000       DemangledName.find(kOCLBuiltinName::WritePipe) == 0)
1001       && CI->getNumArgOperands() > 4)
1002     NewName = std::string(kSPIRVName::ReservedPrefix) + DemangledName;
1003   OCLBuiltinTransInfo Info;
1004   Info.UniqName = NewName;
1005   transBuiltin(CI, Info);
1006 }
1007 
visitCallReadImageMSAA(CallInst * CI,StringRef MangledName,const std::string & DemangledName)1008 void OCL20ToSPIRV::visitCallReadImageMSAA(CallInst *CI, StringRef MangledName,
1009                                           const std::string &DemangledName) {
1010   assert(MangledName.find("msaa") != StringRef::npos);
1011   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
1012   mutateCallInstSPIRV(
1013       M, CI,
1014       [=](CallInst *, std::vector<Value *> &Args) {
1015         Args.insert(Args.begin() + 2, getInt32(M, ImageOperandsSampleMask));
1016         return getSPIRVFuncName(OpImageRead,
1017                                 std::string(kSPIRVPostfix::ExtDivider) +
1018                                     getPostfixForReturnType(CI));
1019       },
1020       &Attrs);
1021 }
1022 
visitCallReadImageWithSampler(CallInst * CI,StringRef MangledName,const std::string & DemangledName)1023 void OCL20ToSPIRV::visitCallReadImageWithSampler(
1024     CallInst *CI, StringRef MangledName, const std::string &DemangledName) {
1025   assert (MangledName.find(kMangledName::Sampler) != StringRef::npos);
1026   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
1027   bool isRetScalar = !CI->getType()->isVectorTy();
1028   mutateCallInstSPIRV(
1029       M, CI,
1030       [=](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
1031         auto ImageTy = getAnalysis<OCLTypeToSPIRV>().getAdaptedType(Args[0]);
1032         if (isOCLImageType(ImageTy))
1033           ImageTy = getSPIRVImageTypeFromOCL(M, ImageTy);
1034         auto SampledImgTy = getSPIRVTypeByChangeBaseTypeName(
1035             M, ImageTy, kSPIRVTypeName::Image, kSPIRVTypeName::SampledImg);
1036         Value *SampledImgArgs[] = {Args[0], Args[1]};
1037         auto SampledImg = addCallInstSPIRV(
1038             M, getSPIRVFuncName(OpSampledImage), SampledImgTy, SampledImgArgs,
1039             nullptr, CI, kSPIRVName::TempSampledImage);
1040 
1041         Args[0] = SampledImg;
1042         Args.erase(Args.begin() + 1, Args.begin() + 2);
1043 
1044         switch (Args.size()) {
1045         case 2: // no lod
1046           Args.push_back(getInt32(M, ImageOperandsMask::ImageOperandsLodMask));
1047           Args.push_back(getFloat32(M, 0.f));
1048           break;
1049         case 3: // explicit lod
1050           Args.insert(Args.begin() + 2,
1051                       getInt32(M, ImageOperandsMask::ImageOperandsLodMask));
1052           break;
1053         case 4: // gradient
1054           Args.insert(Args.begin() + 2,
1055                       getInt32(M, ImageOperandsMask::ImageOperandsGradMask));
1056           break;
1057         default:
1058           assert(0 && "read_image* with unhandled number of args!");
1059         }
1060 
1061         // SPIR-V intruction always returns 4-element vector
1062         if (isRetScalar)
1063           Ret = VectorType::get(Ret, 4);
1064         return getSPIRVFuncName(OpImageSampleExplicitLod,
1065                                 std::string(kSPIRVPostfix::ExtDivider) +
1066                                     getPostfixForReturnType(Ret));
1067       },
1068       [&](CallInst *CI) -> Instruction * {
1069         if (isRetScalar)
1070           return ExtractElementInst::Create(CI, getSizet(M, 0), "",
1071                                             CI->getNextNode());
1072         return CI;
1073       },
1074       &Attrs);
1075 }
1076 
1077 void
visitCallGetImageSize(CallInst * CI,StringRef MangledName,const std::string & DemangledName)1078 OCL20ToSPIRV::visitCallGetImageSize(CallInst* CI,
1079     StringRef MangledName, const std::string& DemangledName) {
1080   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
1081   StringRef TyName;
1082   SmallVector<StringRef, 4> SubStrs;
1083   auto IsImg = isOCLImageType(CI->getArgOperand(0)->getType(), &TyName);
1084   (void)IsImg;  // prevent warning about unused variable in NDEBUG build
1085   assert(IsImg);
1086   std::string ImageTyName = TyName.str();
1087   if (hasAccessQualifiedName(TyName))
1088     ImageTyName.erase(ImageTyName.size() - 5, 3);
1089   auto Desc = map<SPIRVTypeImageDescriptor>(ImageTyName);
1090   unsigned Dim = getImageDimension(Desc.Dim) + Desc.Arrayed;
1091   assert(Dim > 0 && "Invalid image dimension.");
1092   mutateCallInstSPIRV(M, CI,
1093     [&](CallInst *, std::vector<Value *> &Args, Type *&Ret){
1094       assert(Args.size() == 1);
1095       Ret = CI->getType()->isIntegerTy(64) ? Type::getInt64Ty(*Ctx)
1096                                            : Type::getInt32Ty(*Ctx);
1097       if (Dim > 1)
1098         Ret = VectorType::get(Ret, Dim);
1099       if (Desc.Dim == DimBuffer)
1100         return getSPIRVFuncName(OpImageQuerySize, CI->getType());
1101       else {
1102         Args.push_back(getInt32(M, 0));
1103         return getSPIRVFuncName(OpImageQuerySizeLod, CI->getType());
1104       }
1105     },
1106     [&](CallInst *NCI)->Instruction * {
1107       if (Dim == 1)
1108         return NCI;
1109       if (DemangledName == kOCLBuiltinName::GetImageDim) {
1110         if (Desc.Dim == Dim3D) {
1111           auto ZeroVec = ConstantVector::getSplat(3,
1112             Constant::getNullValue(NCI->getType()->getVectorElementType()));
1113           Constant *Index[] = {getInt32(M, 0), getInt32(M, 1),
1114               getInt32(M, 2), getInt32(M, 3)};
1115           return new ShuffleVectorInst(NCI, ZeroVec,
1116              ConstantVector::get(Index), "", CI);
1117 
1118         } else if (Desc.Dim == Dim2D && Desc.Arrayed) {
1119           Constant *Index[] = {getInt32(M, 0), getInt32(M, 1)};
1120           Constant *mask = ConstantVector::get(Index);
1121           return new ShuffleVectorInst(NCI, UndefValue::get(NCI->getType()),
1122                                        mask, NCI->getName(), CI);
1123         }
1124         return NCI;
1125       }
1126       unsigned I = StringSwitch<unsigned>(DemangledName)
1127           .Case(kOCLBuiltinName::GetImageWidth, 0)
1128           .Case(kOCLBuiltinName::GetImageHeight, 1)
1129           .Case(kOCLBuiltinName::GetImageDepth, 2)
1130           .Case(kOCLBuiltinName::GetImageArraySize, Dim - 1);
1131       return ExtractElementInst::Create(NCI, getUInt32(M, I), "",
1132           NCI->getNextNode());
1133     },
1134   &Attrs);
1135 }
1136 
1137 /// Remove trivial conversion functions
1138 bool
eraseUselessConvert(CallInst * CI,const std::string & MangledName,const std::string & DemangledName)1139 OCL20ToSPIRV::eraseUselessConvert(CallInst *CI,
1140     const std::string &MangledName,
1141     const std::string &DemangledName) {
1142   auto TargetTy = CI->getType();
1143   auto SrcTy = CI->getArgOperand(0)->getType();
1144   if (isa<VectorType>(TargetTy))
1145     TargetTy = TargetTy->getVectorElementType();
1146   if (isa<VectorType>(SrcTy))
1147     SrcTy = SrcTy->getVectorElementType();
1148   if (TargetTy == SrcTy) {
1149     if (isa<IntegerType>(TargetTy) &&
1150         DemangledName.find("_sat") != std::string::npos &&
1151         isLastFuncParamSigned(MangledName) != (DemangledName[8] != 'u'))
1152       return false;
1153     CI->getArgOperand(0)->takeName(CI);
1154     SPIRVDBG(dbgs() << "[regularizeOCLConvert] " << *CI << " <- " <<
1155         *CI->getArgOperand(0) << '\n');
1156     CI->replaceAllUsesWith(CI->getArgOperand(0));
1157     ValuesToDelete.insert(CI);
1158     ValuesToDelete.insert(CI->getCalledFunction());
1159     return true;
1160   }
1161   return false;
1162 }
1163 
1164 void
visitCallBuiltinSimple(CallInst * CI,StringRef MangledName,const std::string & DemangledName)1165 OCL20ToSPIRV::visitCallBuiltinSimple(CallInst* CI,
1166     StringRef MangledName, const std::string& DemangledName) {
1167   OCLBuiltinTransInfo Info;
1168   Info.MangledName = MangledName.str();
1169   Info.UniqName = DemangledName;
1170   transBuiltin(CI, Info);
1171 }
1172 
1173 /// Translates OCL work-item builtin functions to SPIRV builtin variables.
1174 /// Function like get_global_id(i) -> x = load GlobalInvocationId; extract x, i
1175 /// Function like get_work_dim() -> load WorkDim
transWorkItemBuiltinsToVariables()1176 void OCL20ToSPIRV::transWorkItemBuiltinsToVariables() {
1177   DEBUG(dbgs() << "Enter transWorkItemBuiltinsToVariables\n");
1178   std::vector<Function *> WorkList;
1179   for (auto I = M->begin(), E = M->end(); I != E; ++I) {
1180     std::string DemangledName;
1181     if (!oclIsBuiltin(I->getName(), &DemangledName))
1182       continue;
1183     DEBUG(dbgs() << "Function demangled name: " << DemangledName << '\n');
1184     std::string BuiltinVarName;
1185     SPIRVBuiltinVariableKind BVKind;
1186     if (!SPIRSPIRVBuiltinVariableMap::find(DemangledName, &BVKind))
1187       continue;
1188     BuiltinVarName = std::string(kSPIRVName::Prefix) +
1189         SPIRVBuiltInNameMap::map(BVKind);
1190     DEBUG(dbgs() << "builtin variable name: " << BuiltinVarName << '\n');
1191     bool IsVec = I->getFunctionType()->getNumParams() > 0;
1192     Type *GVType = IsVec ? VectorType::get(I->getReturnType(),3) :
1193         I->getReturnType();
1194     auto BV = new GlobalVariable(*M, GVType,
1195         true,
1196         GlobalValue::ExternalLinkage,
1197         nullptr, BuiltinVarName,
1198         0,
1199         GlobalVariable::NotThreadLocal,
1200         SPIRAS_Constant);
1201     std::vector<Instruction *> InstList;
1202     for (auto UI = I->user_begin(), UE = I->user_end(); UI != UE; ++UI) {
1203       auto CI = dyn_cast<CallInst>(*UI);
1204       assert(CI && "invalid instruction");
1205       Value * NewValue = new LoadInst(BV, "", CI);
1206       DEBUG(dbgs() << "Transform: " << *CI << " => " << *NewValue << '\n');
1207       if (IsVec) {
1208         NewValue = ExtractElementInst::Create(NewValue,
1209           CI->getArgOperand(0),
1210           "", CI);
1211         DEBUG(dbgs() << *NewValue << '\n');
1212       }
1213       NewValue->takeName(CI);
1214       CI->replaceAllUsesWith(NewValue);
1215       InstList.push_back(CI);
1216     }
1217     for (auto &Inst:InstList) {
1218       Inst->dropAllReferences();
1219       Inst->removeFromParent();
1220     }
1221     WorkList.push_back(static_cast<Function*>(I));
1222   }
1223   for (auto &I:WorkList) {
1224     I->dropAllReferences();
1225     I->removeFromParent();
1226   }
1227 }
1228 
1229 void
visitCallReadWriteImage(CallInst * CI,StringRef MangledName,const std::string & DemangledName)1230 OCL20ToSPIRV::visitCallReadWriteImage(CallInst* CI,
1231     StringRef MangledName, const std::string& DemangledName) {
1232   OCLBuiltinTransInfo Info;
1233   if (DemangledName.find(kOCLBuiltinName::ReadImage) == 0)
1234     Info.UniqName = kOCLBuiltinName::ReadImage;
1235 
1236   if (DemangledName.find(kOCLBuiltinName::WriteImage) == 0)
1237   {
1238     Info.UniqName = kOCLBuiltinName::WriteImage;
1239     Info.PostProc = [&](std::vector<Value*> &Args) {
1240         if (Args.size() == 4) // write with lod
1241         {
1242             auto Lod = Args[2];
1243             Args.erase(Args.begin() + 2);
1244             Args.push_back(getInt32(M, ImageOperandsMask::ImageOperandsLodMask));
1245             Args.push_back(Lod);
1246         }
1247     };
1248   }
1249 
1250   transBuiltin(CI, Info);
1251 }
1252 
1253 void
visitCallToAddr(CallInst * CI,StringRef MangledName,const std::string & DemangledName)1254 OCL20ToSPIRV::visitCallToAddr(CallInst* CI, StringRef MangledName,
1255     const std::string &DemangledName) {
1256   auto AddrSpace = static_cast<SPIRAddressSpace>(
1257       CI->getType()->getPointerAddressSpace());
1258   OCLBuiltinTransInfo Info;
1259   Info.UniqName = DemangledName;
1260   Info.Postfix = std::string(kSPIRVPostfix::Divider) + "To" +
1261       SPIRAddrSpaceCapitalizedNameMap::map(AddrSpace);
1262   auto StorageClass = addInt32(SPIRSPIRVAddrSpaceMap::map(AddrSpace));
1263   Info.RetTy = getInt8PtrTy(cast<PointerType>(CI->getType()));
1264   Info.PostProc = [=](std::vector<Value *> &Ops){
1265     auto P = Ops.back();
1266     Ops.pop_back();
1267     Ops.push_back(castToInt8Ptr(P, CI));
1268     Ops.push_back(StorageClass);
1269   };
1270   transBuiltin(CI, Info);
1271 }
1272 
visitCallRelational(CallInst * CI,const std::string & DemangledName)1273 void OCL20ToSPIRV::visitCallRelational(CallInst *CI,
1274                                        const std::string &DemangledName) {
1275   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
1276   Op OC = OpNop;
1277   OCLSPIRVBuiltinMap::find(DemangledName, &OC);
1278   std::string SPIRVName = getSPIRVFuncName(OC);
1279   mutateCallInstSPIRV(
1280       M, CI,
1281       [=](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
1282         Ret = Type::getInt1Ty(*Ctx);
1283         if (CI->getOperand(0)->getType()->isVectorTy())
1284           Ret = VectorType::get(
1285               Type::getInt1Ty(*Ctx),
1286               CI->getOperand(0)->getType()->getVectorNumElements());
1287         return SPIRVName;
1288       },
1289       [=](CallInst *NewCI) -> Instruction * {
1290         Value *False = nullptr, *True = nullptr;
1291         if (NewCI->getType()->isVectorTy()) {
1292           Type *IntTy = Type::getInt32Ty(*Ctx);
1293           if (cast<VectorType>(NewCI->getOperand(0)->getType())
1294                   ->getElementType()
1295                   ->isDoubleTy())
1296             IntTy = Type::getInt64Ty(*Ctx);
1297           if (cast<VectorType>(NewCI->getOperand(0)->getType())
1298                   ->getElementType()
1299                   ->isHalfTy())
1300             IntTy = Type::getInt16Ty(*Ctx);
1301           Type *VTy = VectorType::get(IntTy,
1302                                       NewCI->getType()->getVectorNumElements());
1303           False = Constant::getNullValue(VTy);
1304           True = Constant::getAllOnesValue(VTy);
1305         } else {
1306           False = getInt32(M, 0);
1307           True = getInt32(M, 1);
1308         }
1309         return SelectInst::Create(NewCI, True, False, "", NewCI->getNextNode());
1310       },
1311       &Attrs);
1312 }
1313 
1314 void
visitCallVecLoadStore(CallInst * CI,StringRef MangledName,const std::string & OrigDemangledName)1315 OCL20ToSPIRV::visitCallVecLoadStore(CallInst* CI,
1316     StringRef MangledName, const std::string& OrigDemangledName) {
1317   std::vector<int> PreOps;
1318   std::string DemangledName = OrigDemangledName;
1319   if (DemangledName.find(kOCLBuiltinName::VLoadPrefix) == 0 &&
1320       DemangledName != kOCLBuiltinName::VLoadHalf) {
1321     SPIRVWord Width = getVecLoadWidth(DemangledName);
1322     SPIRVDBG(spvdbgs() << "[visitCallVecLoadStore] DemangledName: " <<
1323         DemangledName << " Width: " << Width << '\n');
1324     PreOps.push_back(Width);
1325   } else if (DemangledName.find(kOCLBuiltinName::RoundingPrefix)
1326       != std::string::npos) {
1327     auto R = SPIRSPIRVFPRoundingModeMap::map(DemangledName.substr(
1328         DemangledName.find(kOCLBuiltinName::RoundingPrefix) + 1, 3));
1329     PreOps.push_back(R);
1330   }
1331 
1332   if (DemangledName.find(kOCLBuiltinName::VLoadAPrefix) == 0)
1333     transVecLoadStoreName(DemangledName, kOCLBuiltinName::VLoadAPrefix, true);
1334   else
1335     transVecLoadStoreName(DemangledName, kOCLBuiltinName::VLoadPrefix, false);
1336 
1337   if (DemangledName.find(kOCLBuiltinName::VStoreAPrefix) == 0)
1338     transVecLoadStoreName(DemangledName, kOCLBuiltinName::VStoreAPrefix, true);
1339   else
1340     transVecLoadStoreName(DemangledName, kOCLBuiltinName::VStorePrefix, false);
1341 
1342 
1343   auto Consts = getInt32(M, PreOps);
1344   OCLBuiltinTransInfo Info;
1345   Info.MangledName = MangledName;
1346   Info.UniqName = DemangledName;
1347   if (DemangledName.find(kOCLBuiltinName::VLoadPrefix) == 0)
1348     Info.Postfix = std::string(kSPIRVPostfix::ExtDivider) +
1349       getPostfixForReturnType(CI);
1350   Info.PostProc = [=](std::vector<Value *> &Ops){
1351     Ops.insert(Ops.end(), Consts.begin(), Consts.end());
1352   };
1353   transBuiltin(CI, Info);
1354 }
1355 
visitCallGetFence(CallInst * CI,StringRef MangledName,const std::string & DemangledName)1356 void OCL20ToSPIRV::visitCallGetFence(CallInst *CI, StringRef MangledName,
1357                                      const std::string &DemangledName) {
1358   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
1359   Op OC = OpNop;
1360   OCLSPIRVBuiltinMap::find(DemangledName, &OC);
1361   std::string SPIRVName = getSPIRVFuncName(OC);
1362   mutateCallInstSPIRV(M, CI, [=](CallInst *, std::vector<Value *> &Args,
1363                                  Type *&Ret) { return SPIRVName; },
1364             [=](CallInst *NewCI) -> Instruction * {
1365               return BinaryOperator::CreateLShr(NewCI, getInt32(M, 8), "", CI);
1366             },
1367             &Attrs);
1368 }
1369 
visitCallDot(CallInst * CI)1370 void OCL20ToSPIRV::visitCallDot(CallInst *CI) {
1371   IRBuilder<> Builder(CI);
1372   Value *FMulVal = Builder.CreateFMul(CI->getOperand(0), CI->getOperand(1));
1373   CI->replaceAllUsesWith(FMulVal);
1374   CI->dropAllReferences();
1375   CI->removeFromParent();
1376 }
1377 
visitCallScalToVec(CallInst * CI,StringRef MangledName,const std::string & DemangledName)1378 void OCL20ToSPIRV::visitCallScalToVec(CallInst *CI, StringRef MangledName,
1379                                       const std::string &DemangledName) {
1380   // Check if all arguments have the same type - it's simple case.
1381   auto Uniform = true;
1382   auto IsArg0Vector = isa<VectorType>(CI->getOperand(0)->getType());
1383   for (unsigned I = 1, E = CI->getNumArgOperands(); Uniform && (I != E); ++I) {
1384     Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
1385   }
1386   if (Uniform) {
1387     visitCallBuiltinSimple(CI, MangledName, DemangledName);
1388     return;
1389   }
1390 
1391   std::vector<unsigned int> VecPos;
1392   std::vector<unsigned int> ScalarPos;
1393   if (DemangledName == kOCLBuiltinName::FMin ||
1394       DemangledName == kOCLBuiltinName::FMax ||
1395       DemangledName == kOCLBuiltinName::Min ||
1396       DemangledName == kOCLBuiltinName::Max) {
1397     VecPos.push_back(0);
1398     ScalarPos.push_back(1);
1399   } else if (DemangledName == kOCLBuiltinName::Clamp) {
1400     VecPos.push_back(0);
1401     ScalarPos.push_back(1);
1402     ScalarPos.push_back(2);
1403   } else if (DemangledName == kOCLBuiltinName::Mix) {
1404     VecPos.push_back(0);
1405     VecPos.push_back(1);
1406     ScalarPos.push_back(2);
1407   } else if (DemangledName == kOCLBuiltinName::Step) {
1408     VecPos.push_back(1);
1409     ScalarPos.push_back(0);
1410   } else if (DemangledName == kOCLBuiltinName::SmoothStep) {
1411     VecPos.push_back(2);
1412     ScalarPos.push_back(0);
1413     ScalarPos.push_back(1);
1414   }
1415 
1416   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
1417   mutateCallInstSPIRV(
1418       M, CI,
1419       [=](CallInst *, std::vector<Value *> &Args) {
1420         Args.resize(VecPos.size() + ScalarPos.size());
1421         for (auto I : VecPos) {
1422           Args[I] = CI->getOperand(I);
1423         }
1424         auto VecArgWidth =
1425             CI->getOperand(VecPos[0])->getType()->getVectorNumElements();
1426         for (auto I : ScalarPos) {
1427           Instruction *Inst = InsertElementInst::Create(
1428               UndefValue::get(CI->getOperand(VecPos[0])->getType()),
1429               CI->getOperand(I), getInt32(M, 0), "", CI);
1430           Value *NewVec = new ShuffleVectorInst(
1431               Inst, UndefValue::get(CI->getOperand(VecPos[0])->getType()),
1432               ConstantVector::getSplat(VecArgWidth, getInt32(M, 0)), "", CI);
1433 
1434           Args[I] = NewVec;
1435         }
1436         return getSPIRVExtFuncName(SPIRVEIS_OpenCL,
1437                                    getExtOp(MangledName, DemangledName));
1438       },
1439       &Attrs);
1440 }
1441 
visitCallGetImageChannel(CallInst * CI,StringRef MangledName,const std::string & DemangledName,unsigned int Offset)1442 void OCL20ToSPIRV::visitCallGetImageChannel(CallInst *CI, StringRef MangledName,
1443                                             const std::string &DemangledName,
1444                                             unsigned int Offset) {
1445   AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
1446   Op OC = OpNop;
1447   OCLSPIRVBuiltinMap::find(DemangledName, &OC);
1448   std::string SPIRVName = getSPIRVFuncName(OC);
1449   mutateCallInstSPIRV(M, CI, [=](CallInst *, std::vector<Value *> &Args,
1450                                  Type *&Ret) { return SPIRVName; },
1451                       [=](CallInst *NewCI) -> Instruction * {
1452                         return BinaryOperator::CreateAdd(
1453                             NewCI, getInt32(M, Offset), "", CI);
1454                       },
1455                       &Attrs);
1456 }
1457 }
1458 
1459 INITIALIZE_PASS_BEGIN(OCL20ToSPIRV, "cl20tospv", "Transform OCL 2.0 to SPIR-V",
1460     false, false)
INITIALIZE_PASS_DEPENDENCY(OCLTypeToSPIRV)1461 INITIALIZE_PASS_DEPENDENCY(OCLTypeToSPIRV)
1462 INITIALIZE_PASS_END(OCL20ToSPIRV, "cl20tospv", "Transform OCL 2.0 to SPIR-V",
1463     false, false)
1464 
1465 ModulePass *llvm::createOCL20ToSPIRV() {
1466   return new OCL20ToSPIRV();
1467 }
1468