1 //===- SPIRVUtil.cpp - SPIR-V Utilities -------------------------*- C++ -*-===//
2 //
3 //                     The LLVM/SPIRV Translator
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 // Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
9 //
10 // Permission is hereby granted, free of charge, to any person obtaining a
11 // copy of this software and associated documentation files (the "Software"),
12 // to deal with the Software without restriction, including without limitation
13 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
14 // and/or sell copies of the Software, and to permit persons to whom the
15 // Software is furnished to do so, subject to the following conditions:
16 //
17 // Redistributions of source code must retain the above copyright notice,
18 // this list of conditions and the following disclaimers.
19 // Redistributions in binary form must reproduce the above copyright notice,
20 // this list of conditions and the following disclaimers in the documentation
21 // and/or other materials provided with the distribution.
22 // Neither the names of Advanced Micro Devices, Inc., nor the names of its
23 // contributors may be used to endorse or promote products derived from this
24 // Software without specific prior written permission.
25 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28 // CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
31 // THE SOFTWARE.
32 //
33 //===----------------------------------------------------------------------===//
34 /// \file
35 ///
36 /// This file defines utility classes and functions shared by SPIR-V
37 /// reader/writer.
38 ///
39 //===----------------------------------------------------------------------===//
40 
#include "SPIRVInternal.h"
#include "libSPIRV/SPIRVDecorate.h"
#include "libSPIRV/SPIRVValue.h"
#include "SPIRVMDWalker.h"
#include "OCLUtil.h"

#include "llvm/ADT/StringSwitch.h"
#include "llvm/Bitcode/ReaderWriter.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/raw_ostream.h"

#include <functional>
#include <sstream>
#include <string>
59 
60 #define DEBUG_TYPE "spirv"
61 
62 namespace SPIRV{
63 
#ifdef _SPIRV_SUPPORT_TEXT_FMT
// "-spirv-text" command-line flag. The parsed value is stored in
// SPIRVUseTextFormat (external storage selected by cl::location).
cl::opt<bool, true> UseTextFormat(
    "spirv-text",
    cl::desc("Use text format for SPIR-V for debugging purpose"),
    cl::location(SPIRVUseTextFormat));
#endif

#ifdef _SPIRVDBG
// "-spirv-debug" command-line flag. The parsed value is stored in
// SPIRVDbgEnable (external storage selected by cl::location).
cl::opt<bool, true> EnableDbgOutput(
    "spirv-debug",
    cl::desc("Enable SPIR-V debug output"),
    cl::location(SPIRVDbgEnable));
#endif
77 
78 void
addFnAttr(LLVMContext * Context,CallInst * Call,Attribute::AttrKind Attr)79 addFnAttr(LLVMContext *Context, CallInst *Call, Attribute::AttrKind Attr) {
80   Call->addAttribute(AttributeSet::FunctionIndex, Attr);
81 }
82 
83 void
removeFnAttr(LLVMContext * Context,CallInst * Call,Attribute::AttrKind Attr)84 removeFnAttr(LLVMContext *Context, CallInst *Call, Attribute::AttrKind Attr) {
85   Call->removeAttribute(AttributeSet::FunctionIndex,
86       Attribute::get(*Context, Attr));
87 }
88 
89 Value *
removeCast(Value * V)90 removeCast(Value *V) {
91   auto Cast = dyn_cast<ConstantExpr>(V);
92   if (Cast && Cast->isCast()) {
93     return removeCast(Cast->getOperand(0));
94   }
95   if (auto Cast = dyn_cast<CastInst>(V))
96     return removeCast(Cast->getOperand(0));
97   return V;
98 }
99 
100 void
saveLLVMModule(Module * M,const std::string & OutputFile)101 saveLLVMModule(Module *M, const std::string &OutputFile) {
102   std::error_code EC;
103   tool_output_file Out(OutputFile.c_str(), EC, sys::fs::F_None);
104   if (EC) {
105     SPIRVDBG(errs() << "Fails to open output file: " << EC.message();)
106     return;
107   }
108 
109   WriteBitcodeToFile(M, Out.os());
110   Out.keep();
111 }
112 
113 std::string
mapLLVMTypeToOCLType(const Type * Ty,bool Signed)114 mapLLVMTypeToOCLType(const Type* Ty, bool Signed) {
115   if (Ty->isHalfTy())
116     return "half";
117   if (Ty->isFloatTy())
118     return "float";
119   if (Ty->isDoubleTy())
120     return "double";
121   if (auto intTy = dyn_cast<IntegerType>(Ty)) {
122     std::string SignPrefix;
123     std::string Stem;
124     if (!Signed)
125       SignPrefix = "u";
126     switch (intTy->getIntegerBitWidth()) {
127     case 8:
128       Stem = "char";
129       break;
130     case 16:
131       Stem = "short";
132       break;
133     case 32:
134       Stem = "int";
135       break;
136     case 64:
137       Stem = "long";
138       break;
139     default:
140       Stem = "invalid_type";
141       break;
142     }
143     return SignPrefix + Stem;
144   }
145   if (auto vecTy = dyn_cast<VectorType>(Ty)) {
146     Type* eleTy = vecTy->getElementType();
147     unsigned size = vecTy->getVectorNumElements();
148     std::stringstream ss;
149     ss << mapLLVMTypeToOCLType(eleTy, Signed) << size;
150     return ss.str();
151   }
152   return "invalid_type";
153 }
154 
155 std::string
mapSPIRVTypeToOCLType(SPIRVType * Ty,bool Signed)156 mapSPIRVTypeToOCLType(SPIRVType* Ty, bool Signed) {
157   if (Ty->isTypeFloat()) {
158     auto W = Ty->getBitWidth();
159     switch (W) {
160     case 16:
161       return "half";
162     case 32:
163       return "float";
164     case 64:
165       return "double";
166     default:
167       assert (0 && "Invalid floating pointer type");
168       return std::string("float") + W + "_t";
169     }
170   }
171   if (Ty->isTypeInt()) {
172     std::string SignPrefix;
173     std::string Stem;
174     if (!Signed)
175       SignPrefix = "u";
176     auto W = Ty->getBitWidth();
177     switch (W) {
178     case 8:
179       Stem = "char";
180       break;
181     case 16:
182       Stem = "short";
183       break;
184     case 32:
185       Stem = "int";
186       break;
187     case 64:
188       Stem = "long";
189       break;
190     default:
191       llvm_unreachable("Invalid integer type");
192       Stem = std::string("int") + W + "_t";
193       break;
194     }
195     return SignPrefix + Stem;
196   }
197   if (Ty->isTypeVector()) {
198     auto eleTy = Ty->getVectorComponentType();
199     auto size = Ty->getVectorComponentCount();
200     std::stringstream ss;
201     ss << mapSPIRVTypeToOCLType(eleTy, Signed) << size;
202     return ss.str();
203   }
204   llvm_unreachable("Invalid type");
205   return "unknown_type";
206 }
207 
208 PointerType*
getOrCreateOpaquePtrType(Module * M,const std::string & Name,unsigned AddrSpace)209 getOrCreateOpaquePtrType(Module *M, const std::string &Name,
210     unsigned AddrSpace) {
211   auto OpaqueType = M->getTypeByName(Name);
212   if (!OpaqueType)
213     OpaqueType = StructType::create(M->getContext(), Name);
214   return PointerType::get(OpaqueType, AddrSpace);
215 }
216 
217 PointerType*
getSamplerType(Module * M)218 getSamplerType(Module *M) {
219   return getOrCreateOpaquePtrType(M, getSPIRVTypeName(kSPIRVTypeName::Sampler),
220                                   SPIRAS_Constant);
221 }
222 
223 PointerType*
getPipeStorageType(Module * M)224 getPipeStorageType(Module* M) {
225   return getOrCreateOpaquePtrType(M, getSPIRVTypeName(
226                                         kSPIRVTypeName::PipeStorage),
227                                         SPIRAS_Constant);
228 }
229 
230 
231 void
getFunctionTypeParameterTypes(llvm::FunctionType * FT,std::vector<Type * > & ArgTys)232 getFunctionTypeParameterTypes(llvm::FunctionType* FT,
233     std::vector<Type*>& ArgTys) {
234   for (auto I = FT->param_begin(), E = FT->param_end(); I != E; ++I) {
235     ArgTys.push_back(*I);
236   }
237 }
238 
239 bool
isVoidFuncTy(FunctionType * FT)240 isVoidFuncTy(FunctionType *FT) {
241   return FT->getReturnType()->isVoidTy() && FT->getNumParams() == 0;
242 }
243 
244 bool
isPointerToOpaqueStructType(llvm::Type * Ty)245 isPointerToOpaqueStructType(llvm::Type* Ty) {
246   if (auto PT = dyn_cast<PointerType>(Ty))
247     if (auto ST = dyn_cast<StructType>(PT->getElementType()))
248       if (ST->isOpaque())
249         return true;
250   return false;
251 }
252 
253 bool
isPointerToOpaqueStructType(llvm::Type * Ty,const std::string & Name)254 isPointerToOpaqueStructType(llvm::Type* Ty, const std::string &Name) {
255   if (auto PT = dyn_cast<PointerType>(Ty))
256     if (auto ST = dyn_cast<StructType>(PT->getElementType()))
257       if (ST->isOpaque() && ST->getName() == Name)
258         return true;
259   return false;
260 }
261 
262 bool
isOCLImageType(llvm::Type * Ty,StringRef * Name)263 isOCLImageType(llvm::Type* Ty, StringRef *Name) {
264   if (auto PT = dyn_cast<PointerType>(Ty))
265     if (auto ST = dyn_cast<StructType>(PT->getElementType()))
266       if (ST->isOpaque()) {
267         auto FullName = ST->getName();
268         if (FullName.find(kSPR2TypeName::ImagePrefix) == 0) {
269           if (Name)
270             *Name = FullName.drop_front(strlen(kSPR2TypeName::OCLPrefix));
271           return true;
272         }
273       }
274   return false;
275 }
276 
277 /// \param BaseTyName is the type name as in spirv.BaseTyName.Postfixes
278 /// \param Postfix contains postfixes extracted from the SPIR-V image
279 ///   type name as spirv.BaseTyName.Postfixes.
280 bool
isSPIRVType(llvm::Type * Ty,StringRef BaseTyName,StringRef * Postfix)281 isSPIRVType(llvm::Type* Ty, StringRef BaseTyName, StringRef *Postfix) {
282   if (auto PT = dyn_cast<PointerType>(Ty))
283     if (auto ST = dyn_cast<StructType>(PT->getElementType()))
284       if (ST->isOpaque()) {
285         auto FullName = ST->getName();
286         std::string N = std::string(kSPIRVTypeName::PrefixAndDelim) +
287           BaseTyName.str();
288         if (FullName != N)
289           N = N + kSPIRVTypeName::Delimiter;
290         if (FullName.startswith(N)) {
291           if (Postfix)
292             *Postfix = FullName.drop_front(N.size());
293           return true;
294         }
295       }
296   return false;
297 }
298 
299 Function *
getOrCreateFunction(Module * M,Type * RetTy,ArrayRef<Type * > ArgTypes,StringRef Name,BuiltinFuncMangleInfo * Mangle,AttributeSet * Attrs,bool takeName)300 getOrCreateFunction(Module *M, Type *RetTy, ArrayRef<Type *> ArgTypes,
301     StringRef Name, BuiltinFuncMangleInfo *Mangle, AttributeSet *Attrs,
302     bool takeName) {
303   std::string MangledName = Name;
304   bool isVarArg = false;
305   if (Mangle) {
306     MangledName = mangleBuiltin(Name, ArgTypes, Mangle);
307     isVarArg = 0 <= Mangle->getVarArg();
308     if(isVarArg) ArgTypes = ArgTypes.slice(0, Mangle->getVarArg());
309   }
310   FunctionType *FT = FunctionType::get(RetTy, ArgTypes, isVarArg);
311   Function *F = M->getFunction(MangledName);
312   if (!takeName && F && F->getFunctionType() != FT && Mangle != nullptr) {
313     std::string S;
314     raw_string_ostream SS(S);
315     SS << "Error: Attempt to redefine function: " << *F << " => " <<
316         *FT << '\n';
317     report_fatal_error(SS.str(), false);
318   }
319   if (!F || F->getFunctionType() != FT) {
320     auto NewF = Function::Create(FT,
321       GlobalValue::ExternalLinkage,
322       MangledName,
323       M);
324     if (F && takeName) {
325       NewF->takeName(F);
326       DEBUG(dbgs() << "[getOrCreateFunction] Warning: taking function name\n");
327     }
328     if (NewF->getName() != MangledName) {
329       DEBUG(dbgs() << "[getOrCreateFunction] Warning: function name changed\n");
330     }
331     DEBUG(dbgs() << "[getOrCreateFunction] ";
332       if (F)
333         dbgs() << *F << " => ";
334       dbgs() << *NewF << '\n';
335       );
336     F = NewF;
337     F->setCallingConv(CallingConv::SPIR_FUNC);
338     if (Attrs)
339       F->setAttributes(*Attrs);
340   }
341   return F;
342 }
343 
344 std::vector<Value *>
getArguments(CallInst * CI,unsigned Start,unsigned End)345 getArguments(CallInst* CI, unsigned Start, unsigned End) {
346   std::vector<Value*> Args;
347   if (End == 0)
348     End = CI->getNumArgOperands();
349   for (; Start != End; ++Start) {
350     Args.push_back(CI->getArgOperand(Start));
351   }
352   return Args;
353 }
354 
getArgAsInt(CallInst * CI,unsigned I)355 uint64_t getArgAsInt(CallInst *CI, unsigned I){
356   return cast<ConstantInt>(CI->getArgOperand(I))->getZExtValue();
357 }
358 
getArgAsScope(CallInst * CI,unsigned I)359 Scope getArgAsScope(CallInst *CI, unsigned I){
360   return static_cast<Scope>(getArgAsInt(CI, I));
361 }
362 
getArgAsDecoration(CallInst * CI,unsigned I)363 Decoration getArgAsDecoration(CallInst *CI, unsigned I) {
364   return static_cast<Decoration>(getArgAsInt(CI, I));
365 }
366 
367 std::string
decorateSPIRVFunction(const std::string & S)368 decorateSPIRVFunction(const std::string &S) {
369   return std::string(kSPIRVName::Prefix) + S + kSPIRVName::Postfix;
370 }
371 
372 std::string
undecorateSPIRVFunction(const std::string & S)373 undecorateSPIRVFunction(const std::string& S) {
374   assert (S.find(kSPIRVName::Prefix) == 0);
375   const size_t Start = strlen(kSPIRVName::Prefix);
376   auto End = S.rfind(kSPIRVName::Postfix);
377   return S.substr(Start, End - Start);
378 }
379 
380 std::string
prefixSPIRVName(const std::string & S)381 prefixSPIRVName(const std::string &S) {
382   return std::string(kSPIRVName::Prefix) + S;
383 }
384 
385 StringRef
dePrefixSPIRVName(StringRef R,SmallVectorImpl<StringRef> & Postfix)386 dePrefixSPIRVName(StringRef R,
387     SmallVectorImpl<StringRef> &Postfix) {
388   const size_t Start = strlen(kSPIRVName::Prefix);
389   if (!R.startswith(kSPIRVName::Prefix))
390     return R;
391   R = R.drop_front(Start);
392   R.split(Postfix, "_", -1, false);
393   auto Name = Postfix.front();
394   Postfix.erase(Postfix.begin());
395   return Name;
396 }
397 
398 std::string
getSPIRVFuncName(Op OC,StringRef PostFix)399 getSPIRVFuncName(Op OC, StringRef PostFix) {
400   return prefixSPIRVName(getName(OC) + PostFix.str());
401 }
402 
403 std::string
getSPIRVFuncName(Op OC,const Type * pRetTy,bool IsSigned)404 getSPIRVFuncName(Op OC, const Type *pRetTy, bool IsSigned) {
405   return prefixSPIRVName(getName(OC) + kSPIRVPostfix::Divider +
406                          getPostfixForReturnType(pRetTy, false));
407 }
408 
409 std::string
getSPIRVExtFuncName(SPIRVExtInstSetKind Set,unsigned ExtOp,StringRef PostFix)410 getSPIRVExtFuncName(SPIRVExtInstSetKind Set, unsigned ExtOp,
411     StringRef PostFix) {
412   std::string ExtOpName;
413   switch(Set) {
414   default:
415     llvm_unreachable("invalid extended instruction set");
416     ExtOpName = "unknown";
417     break;
418   case SPIRVEIS_OpenCL:
419     ExtOpName = getName(static_cast<OCLExtOpKind>(ExtOp));
420     break;
421   }
422   return prefixSPIRVName(SPIRVExtSetShortNameMap::map(Set)
423       + '_' + ExtOpName + PostFix.str());
424 }
425 
426 SPIRVDecorate *
mapPostfixToDecorate(StringRef Postfix,SPIRVEntry * Target)427 mapPostfixToDecorate(StringRef Postfix, SPIRVEntry *Target) {
428   if (Postfix == kSPIRVPostfix::Sat)
429     return new SPIRVDecorate(spv::DecorationSaturatedConversion, Target);
430 
431   if (Postfix.startswith(kSPIRVPostfix::Rt))
432     return new SPIRVDecorate(spv::DecorationFPRoundingMode, Target,
433       map<SPIRVFPRoundingModeKind>(Postfix.str()));
434 
435   return nullptr;
436 }
437 
438 SPIRVValue *
addDecorations(SPIRVValue * Target,const SmallVectorImpl<std::string> & Decs)439 addDecorations(SPIRVValue *Target, const SmallVectorImpl<std::string>& Decs){
440   for (auto &I:Decs)
441     if (auto Dec = mapPostfixToDecorate(I, Target))
442       Target->addDecorate(Dec);
443   return Target;
444 }
445 
446 std::string
getPostfix(Decoration Dec,unsigned Value)447 getPostfix(Decoration Dec, unsigned Value) {
448   switch(Dec) {
449   default:
450     llvm_unreachable("not implemented");
451     return "unknown";
452   case spv::DecorationSaturatedConversion:
453     return kSPIRVPostfix::Sat;
454   case spv::DecorationFPRoundingMode:
455     return rmap<std::string>(static_cast<SPIRVFPRoundingModeKind>(Value));
456   }
457 }
458 
459 std::string
getPostfixForReturnType(CallInst * CI,bool IsSigned)460 getPostfixForReturnType(CallInst *CI, bool IsSigned) {
461   return getPostfixForReturnType(CI->getType(), IsSigned);
462 }
463 
getPostfixForReturnType(const Type * pRetTy,bool IsSigned)464 std::string getPostfixForReturnType(const Type *pRetTy, bool IsSigned) {
465   return std::string(kSPIRVPostfix::Return) +
466          mapLLVMTypeToOCLType(pRetTy, IsSigned);
467 }
468 
469 Op
getSPIRVFuncOC(const std::string & S,SmallVectorImpl<std::string> * Dec)470 getSPIRVFuncOC(const std::string& S, SmallVectorImpl<std::string> *Dec) {
471   Op OC;
472   SmallVector<StringRef, 2> Postfix;
473   std::string Name;
474   if (!oclIsBuiltin(S, &Name))
475     Name = S;
476   StringRef R(Name);
477   R = dePrefixSPIRVName(R, Postfix);
478   if (!getByName(R.str(), OC))
479     return OpNop;
480   if (Dec)
481     for (auto &I:Postfix)
482       Dec->push_back(I.str());
483   return OC;
484 }
485 
486 bool
getSPIRVBuiltin(const std::string & OrigName,spv::BuiltIn & B)487 getSPIRVBuiltin(const std::string &OrigName, spv::BuiltIn &B) {
488   SmallVector<StringRef, 2> Postfix;
489   StringRef R(OrigName);
490   R = dePrefixSPIRVName(R, Postfix);
491   assert(Postfix.empty() && "Invalid SPIR-V builtin name");
492   return getByName(R.str(), B);
493 }
494 
oclIsBuiltin(const StringRef & Name,std::string * DemangledName,bool isCPP)495 bool oclIsBuiltin(const StringRef &Name, std::string *DemangledName,
496                   bool isCPP) {
497   if (Name == "printf") {
498     if (DemangledName)
499       *DemangledName = Name;
500     return true;
501   }
502   if (!Name.startswith("_Z"))
503     return false;
504   if (!DemangledName)
505     return true;
506   // OpenCL C++ built-ins are declared in cl namespace.
507   // TODO: consider using 'St' abbriviation for cl namespace mangling.
508   // Similar to ::std:: in C++.
509   if (isCPP) {
510     if (!Name.startswith("_ZN"))
511       return false;
512     // Skip CV and ref qualifiers.
513     size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3);
514     // All built-ins are in the ::cl:: namespace.
515     if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv")
516       return false;
517     size_t DemangledNameLenStart = NameSpaceStart + 11;
518     size_t Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
519     size_t Len = 0;
520     Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart)
521         .getAsInteger(10, Len);
522     *DemangledName = Name.substr(Start, Len);
523   } else {
524     size_t Start = Name.find_first_not_of("0123456789", 2);
525     size_t Len = 0;
526     Name.substr(2, Start - 2).getAsInteger(10, Len);
527     *DemangledName = Name.substr(Start, Len);
528   }
529   return true;
530 }
531 
// Check if a mangled (Itanium builtin) type code denotes an unsigned integer:
// 'h' uchar, 't' ushort, 'j' uint, 'm' ulong.
bool isMangledTypeUnsigned(char Mangled) {
  switch (Mangled) {
  case 'h': // uchar
  case 't': // ushort
  case 'j': // uint
  case 'm': // ulong
    return true;
  default:
    return false;
  }
}
539 
// Check if a mangled (Itanium builtin) type code denotes a signed integer:
// 'c' char, 'a' signed char, 's' short, 'i' int, 'l' long.
bool isMangledTypeSigned(char Mangled) {
  switch (Mangled) {
  case 'c': // char
  case 'a': // signed char
  case 's': // short
  case 'i': // int
  case 'l': // long
    return true;
  default:
    return false;
  }
}
548 
// Check if a mangled type code is float ('f') or double ('d'). Half uses the
// two-character code "Dh" and is handled by isMangledTypeHalf.
bool isMangledTypeFP(char Mangled) {
  switch (Mangled) {
  case 'f': // float
  case 'd': // double
    return true;
  default:
    return false;
  }
}
554 
// Check if a mangled type code is exactly "Dh" (half).
bool isMangledTypeHalf(std::string Mangled) {
  return Mangled.size() == 2 && Mangled[0] == 'D' && Mangled[1] == 'h';
}
559 
/// Strip every trailing "S_" (Itanium substitution reference) from
/// \p MangledName in place, e.g. "fooS_S_" becomes "foo".
void
eraseSubstitutionFromMangledName(std::string& MangledName) {
  while (MangledName.size() >= 2 &&
         MangledName.compare(MangledName.size() - 2, 2, "S_") == 0)
    MangledName.resize(MangledName.size() - 2);
}
568 
LastFuncParamType(const std::string & MangledName)569 ParamType LastFuncParamType(const std::string &MangledName) {
570   auto Copy = MangledName;
571   eraseSubstitutionFromMangledName(Copy);
572   char Mangled = Copy.back();
573   std::string Mangled2 = Copy.substr(Copy.size() - 2);
574 
575   if (isMangledTypeFP(Mangled) || isMangledTypeHalf(Mangled2)) {
576     return ParamType::FLOAT;
577   } else if (isMangledTypeUnsigned(Mangled)) {
578     return ParamType::UNSIGNED;
579   } else if (isMangledTypeSigned(Mangled)) {
580     return ParamType::SIGNED;
581   }
582 
583   return ParamType::UNKNOWN;
584 }
585 
586 // Check if the last argument is signed
587 bool
isLastFuncParamSigned(const std::string & MangledName)588 isLastFuncParamSigned(const std::string& MangledName) {
589   return LastFuncParamType(MangledName) == ParamType::SIGNED;
590 }
591 
592 
593 // Check if a mangled function name contains unsigned atomic type
594 bool
containsUnsignedAtomicType(StringRef Name)595 containsUnsignedAtomicType(StringRef Name) {
596   auto Loc = Name.find(kMangledName::AtomicPrefixIncoming);
597   if (Loc == StringRef::npos)
598     return false;
599   return isMangledTypeUnsigned(Name[Loc + strlen(
600       kMangledName::AtomicPrefixIncoming)]);
601 }
602 
603 bool
isFunctionPointerType(Type * T)604 isFunctionPointerType(Type *T) {
605   if (isa<PointerType>(T) &&
606       isa<FunctionType>(T->getPointerElementType())) {
607     return true;
608   }
609   return false;
610 }
611 
612 bool
hasFunctionPointerArg(Function * F,Function::arg_iterator & AI)613 hasFunctionPointerArg(Function *F, Function::arg_iterator& AI) {
614   AI = F->arg_begin();
615   for (auto AE = F->arg_end(); AI != AE; ++AI) {
616     DEBUG(dbgs() << "[hasFuncPointerArg] " << *AI << '\n');
617     if (isFunctionPointerType(AI->getType())) {
618       return true;
619     }
620   }
621   return false;
622 }
623 
624 Constant *
castToVoidFuncPtr(Function * F)625 castToVoidFuncPtr(Function *F) {
626   auto T = getVoidFuncPtrType(F->getParent());
627   return ConstantExpr::getBitCast(F, T);
628 }
629 
630 bool
hasArrayArg(Function * F)631 hasArrayArg(Function *F) {
632   for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) {
633     DEBUG(dbgs() << "[hasArrayArg] " << *I << '\n');
634     if (I->getType()->isArrayTy()) {
635       return true;
636     }
637   }
638   return false;
639 }
640 
641 CallInst *
mutateCallInst(Module * M,CallInst * CI,std::function<std::string (CallInst *,std::vector<Value * > &)> ArgMutate,BuiltinFuncMangleInfo * Mangle,AttributeSet * Attrs,bool TakeFuncName)642 mutateCallInst(Module *M, CallInst *CI,
643     std::function<std::string (CallInst *, std::vector<Value *> &)>ArgMutate,
644     BuiltinFuncMangleInfo *Mangle, AttributeSet *Attrs, bool TakeFuncName) {
645   DEBUG(dbgs() << "[mutateCallInst] " << *CI);
646 
647   auto Args = getArguments(CI);
648   auto NewName = ArgMutate(CI, Args);
649   std::string InstName;
650   if (!CI->getType()->isVoidTy() && CI->hasName()) {
651     InstName = CI->getName();
652     CI->setName(InstName + ".old");
653   }
654   auto NewCI = addCallInst(M, NewName, CI->getType(), Args, Attrs, CI, Mangle,
655       InstName, TakeFuncName);
656   DEBUG(dbgs() << " => " << *NewCI << '\n');
657   CI->replaceAllUsesWith(NewCI);
658   CI->dropAllReferences();
659   CI->removeFromParent();
660   return NewCI;
661 }
662 
663 Instruction *
mutateCallInst(Module * M,CallInst * CI,std::function<std::string (CallInst *,std::vector<Value * > &,Type * & RetTy)> ArgMutate,std::function<Instruction * (CallInst *)> RetMutate,BuiltinFuncMangleInfo * Mangle,AttributeSet * Attrs,bool TakeFuncName)664 mutateCallInst(Module *M, CallInst *CI,
665     std::function<std::string (CallInst *, std::vector<Value *> &,
666         Type *&RetTy)>ArgMutate,
667     std::function<Instruction *(CallInst *)> RetMutate,
668     BuiltinFuncMangleInfo *Mangle, AttributeSet *Attrs, bool TakeFuncName) {
669   DEBUG(dbgs() << "[mutateCallInst] " << *CI);
670 
671   auto Args = getArguments(CI);
672   Type *RetTy = CI->getType();
673   auto NewName = ArgMutate(CI, Args, RetTy);
674   std::string InstName;
675   if (CI->hasName()) {
676     InstName = CI->getName();
677     CI->setName(InstName + ".old");
678   }
679   auto NewCI = addCallInst(M, NewName, RetTy, Args, Attrs,
680       CI, Mangle, InstName + ".tmp", TakeFuncName);
681   auto NewI = RetMutate(NewCI);
682   NewI->takeName(CI);
683   DEBUG(dbgs() << " => " << *NewI << '\n');
684   CI->replaceAllUsesWith(NewI);
685   CI->dropAllReferences();
686   CI->removeFromParent();
687   return NewI;
688 }
689 
690 void
mutateFunction(Function * F,std::function<std::string (CallInst *,std::vector<Value * > &)> ArgMutate,BuiltinFuncMangleInfo * Mangle,AttributeSet * Attrs,bool TakeFuncName)691 mutateFunction(Function *F,
692     std::function<std::string (CallInst *, std::vector<Value *> &)>ArgMutate,
693     BuiltinFuncMangleInfo *Mangle, AttributeSet *Attrs,
694     bool TakeFuncName) {
695   auto M = F->getParent();
696   for (auto I = F->user_begin(), E = F->user_end(); I != E;) {
697     if (auto CI = dyn_cast<CallInst>(*I++))
698       mutateCallInst(M, CI, ArgMutate, Mangle, Attrs, TakeFuncName);
699   }
700   if (F->use_empty())
701     F->eraseFromParent();
702 }
703 
704 CallInst *
mutateCallInstSPIRV(Module * M,CallInst * CI,std::function<std::string (CallInst *,std::vector<Value * > &)> ArgMutate,AttributeSet * Attrs)705 mutateCallInstSPIRV(Module *M, CallInst *CI,
706     std::function<std::string (CallInst *, std::vector<Value *> &)>ArgMutate,
707     AttributeSet *Attrs) {
708   BuiltinFuncMangleInfo BtnInfo;
709   return mutateCallInst(M, CI, ArgMutate, &BtnInfo, Attrs);
710 }
711 
712 Instruction *
mutateCallInstSPIRV(Module * M,CallInst * CI,std::function<std::string (CallInst *,std::vector<Value * > &,Type * & RetTy)> ArgMutate,std::function<Instruction * (CallInst *)> RetMutate,AttributeSet * Attrs)713 mutateCallInstSPIRV(Module *M, CallInst *CI,
714     std::function<std::string (CallInst *, std::vector<Value *> &,
715         Type *&RetTy)> ArgMutate,
716     std::function<Instruction *(CallInst *)> RetMutate,
717     AttributeSet *Attrs) {
718   BuiltinFuncMangleInfo BtnInfo;
719   return mutateCallInst(M, CI, ArgMutate, RetMutate, &BtnInfo, Attrs);
720 }
721 
722 CallInst *
addCallInst(Module * M,StringRef FuncName,Type * RetTy,ArrayRef<Value * > Args,AttributeSet * Attrs,Instruction * Pos,BuiltinFuncMangleInfo * Mangle,StringRef InstName,bool TakeFuncName)723 addCallInst(Module *M, StringRef FuncName, Type *RetTy, ArrayRef<Value *> Args,
724     AttributeSet *Attrs, Instruction *Pos, BuiltinFuncMangleInfo *Mangle,
725     StringRef InstName, bool TakeFuncName) {
726 
727   auto F = getOrCreateFunction(M, RetTy, getTypes(Args),
728       FuncName, Mangle, Attrs, TakeFuncName);
729   // Cannot assign a name to void typed values
730   auto CI = CallInst::Create(F, Args, RetTy->isVoidTy() ? "" : InstName, Pos);
731   CI->setCallingConv(F->getCallingConv());
732   return CI;
733 }
734 
735 CallInst *
addCallInstSPIRV(Module * M,StringRef FuncName,Type * RetTy,ArrayRef<Value * > Args,AttributeSet * Attrs,Instruction * Pos,StringRef InstName)736 addCallInstSPIRV(Module *M, StringRef FuncName, Type *RetTy, ArrayRef<Value *> Args,
737     AttributeSet *Attrs, Instruction *Pos, StringRef InstName) {
738   BuiltinFuncMangleInfo BtnInfo;
739   return addCallInst(M, FuncName, RetTy, Args, Attrs, Pos, &BtnInfo,
740       InstName);
741 }
742 
/// True for the OpenCL vector widths 2, 3, 4, 8 and 16.
bool
isValidVectorSize(unsigned I) {
  switch (I) {
  case 2:
  case 3:
  case 4:
  case 8:
  case 16:
    return true;
  default:
    return false;
  }
}
751 
752 Value *
addVector(Instruction * InsPos,ValueVecRange Range)753 addVector(Instruction *InsPos, ValueVecRange Range) {
754   size_t VecSize = Range.second - Range.first;
755   if (VecSize == 1)
756     return *Range.first;
757   assert(isValidVectorSize(VecSize) && "Invalid vector size");
758   IRBuilder<> Builder(InsPos);
759   auto Vec = Builder.CreateVectorSplat(VecSize, *Range.first);
760   unsigned Index = 1;
761   for (++Range.first; Range.first != Range.second; ++Range.first, ++Index)
762     Vec = Builder.CreateInsertElement(Vec, *Range.first,
763         ConstantInt::get(Type::getInt32Ty(InsPos->getContext()), Index, false));
764   return Vec;
765 }
766 
767 void
makeVector(Instruction * InsPos,std::vector<Value * > & Ops,ValueVecRange Range)768 makeVector(Instruction *InsPos, std::vector<Value *> &Ops,
769     ValueVecRange Range) {
770   auto Vec = addVector(InsPos, Range);
771   Ops.erase(Range.first, Range.second);
772   Ops.push_back(Vec);
773 }
774 
775 void
expandVector(Instruction * InsPos,std::vector<Value * > & Ops,size_t VecPos)776 expandVector(Instruction *InsPos, std::vector<Value *> &Ops,
777     size_t VecPos) {
778   auto Vec = Ops[VecPos];
779   auto VT = Vec->getType();
780   if (!VT->isVectorTy())
781     return;
782   size_t N = VT->getVectorNumElements();
783   IRBuilder<> Builder(InsPos);
784   for (size_t I = 0; I != N; ++I)
785     Ops.insert(Ops.begin() + VecPos + I, Builder.CreateExtractElement(Vec,
786         ConstantInt::get(Type::getInt32Ty(InsPos->getContext()), I, false)));
787   Ops.erase(Ops.begin() + VecPos + N);
788 }
789 
790 Constant *
castToInt8Ptr(Constant * V,unsigned Addr=0)791 castToInt8Ptr(Constant *V, unsigned Addr = 0) {
792   return ConstantExpr::getBitCast(V, Type::getInt8PtrTy(V->getContext(), Addr));
793 }
794 
795 PointerType *
getInt8PtrTy(PointerType * T)796 getInt8PtrTy(PointerType *T) {
797   return Type::getInt8PtrTy(T->getContext(), T->getAddressSpace());
798 }
799 
800 Value *
castToInt8Ptr(Value * V,Instruction * Pos)801 castToInt8Ptr(Value *V, Instruction *Pos) {
802   return CastInst::CreatePointerCast(V, getInt8PtrTy(
803       cast<PointerType>(V->getType())), "", Pos);
804 }
805 
806 CallInst *
addBlockBind(Module * M,Function * InvokeFunc,Value * BlkCtx,Value * CtxLen,Value * CtxAlign,Instruction * InsPos,StringRef InstName)807 addBlockBind(Module *M, Function *InvokeFunc, Value *BlkCtx, Value *CtxLen,
808     Value *CtxAlign, Instruction *InsPos, StringRef InstName) {
809   auto BlkTy = getOrCreateOpaquePtrType(M, SPIR_TYPE_NAME_BLOCK_T,
810       SPIRAS_Private);
811   auto &Ctx = M->getContext();
812   Value *BlkArgs[] = {
813       castToInt8Ptr(InvokeFunc),
814       CtxLen ? CtxLen : UndefValue::get(Type::getInt32Ty(Ctx)),
815       CtxAlign ? CtxAlign : UndefValue::get(Type::getInt32Ty(Ctx)),
816       BlkCtx ? BlkCtx : UndefValue::get(Type::getInt8PtrTy(Ctx))
817   };
818   return addCallInst(M, SPIR_INTRINSIC_BLOCK_BIND, BlkTy, BlkArgs, nullptr,
819       InsPos, nullptr, InstName);
820 }
821 
getSizetType(Module * M)822 IntegerType* getSizetType(Module *M) {
823   return IntegerType::getIntNTy(M->getContext(),
824     M->getDataLayout().getPointerSizeInBits(0));
825 }
826 
827 Type *
getVoidFuncType(Module * M)828 getVoidFuncType(Module *M) {
829   return FunctionType::get(Type::getVoidTy(M->getContext()), false);
830 }
831 
832 Type *
getVoidFuncPtrType(Module * M,unsigned AddrSpace)833 getVoidFuncPtrType(Module *M, unsigned AddrSpace) {
834   return PointerType::get(getVoidFuncType(M), AddrSpace);
835 }
836 
837 ConstantInt *
getInt64(Module * M,int64_t value)838 getInt64(Module *M, int64_t value) {
839   return ConstantInt::get(Type::getInt64Ty(M->getContext()), value, true);
840 }
841 
getFloat32(Module * M,float value)842 Constant *getFloat32(Module *M, float value) {
843   return ConstantFP::get(Type::getFloatTy(M->getContext()), value);
844 }
845 
846 ConstantInt *
getInt32(Module * M,int value)847 getInt32(Module *M, int value) {
848   return ConstantInt::get(Type::getInt32Ty(M->getContext()), value, true);
849 }
850 
851 ConstantInt *
getUInt32(Module * M,unsigned value)852 getUInt32(Module *M, unsigned value) {
853   return ConstantInt::get(Type::getInt32Ty(M->getContext()), value, false);
854 }
855 
856 ConstantInt *
getUInt16(Module * M,unsigned short value)857 getUInt16(Module *M, unsigned short value) {
858   return ConstantInt::get(Type::getInt16Ty(M->getContext()), value, false);
859 }
860 
getInt32(Module * M,const std::vector<int> & value)861 std::vector<Value *> getInt32(Module *M, const std::vector<int> &value) {
862   std::vector<Value *> V;
863   for (auto &I:value)
864     V.push_back(getInt32(M, I));
865   return V;
866 }
867 
868 ConstantInt *
getSizet(Module * M,uint64_t value)869 getSizet(Module *M, uint64_t value) {
870   return ConstantInt::get(getSizetType(M), value, false);
871 }
872 
873 ///////////////////////////////////////////////////////////////////////////////
874 //
875 // Functions for getting metadata
876 //
877 ///////////////////////////////////////////////////////////////////////////////
878 int
getMDOperandAsInt(MDNode * N,unsigned I)879 getMDOperandAsInt(MDNode* N, unsigned I) {
880   return mdconst::dyn_extract<ConstantInt>(N->getOperand(I))->getZExtValue();
881 }
882 
883 std::string
getMDOperandAsString(MDNode * N,unsigned I)884 getMDOperandAsString(MDNode* N, unsigned I) {
885   if (!N)
886     return "";
887 
888   Metadata* Op = N->getOperand(I);
889   if (!Op)
890     return "";
891 
892   if (MDString* Str = dyn_cast<MDString>(Op)) {
893     return Str->getString().str();
894   }
895   return "";
896 }
897 
898 Type*
getMDOperandAsType(MDNode * N,unsigned I)899 getMDOperandAsType(MDNode* N, unsigned I) {
900   return cast<ValueAsMetadata>(N->getOperand(I))->getType();
901 }
902 
903 std::set<std::string>
getNamedMDAsStringSet(Module * M,const std::string & MDName)904 getNamedMDAsStringSet(Module *M, const std::string &MDName) {
905   NamedMDNode *NamedMD = M->getNamedMetadata(MDName);
906   std::set<std::string> StrSet;
907   if (!NamedMD)
908     return StrSet;
909 
910   assert(NamedMD->getNumOperands() > 0 && "Invalid SPIR");
911 
912   for (unsigned I = 0, E = NamedMD->getNumOperands(); I != E; ++I) {
913     MDNode *MD = NamedMD->getOperand(I);
914     if (!MD || MD->getNumOperands() == 0)
915       continue;
916     for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
917       StrSet.insert(getMDOperandAsString(MD, J));
918   }
919 
920   return StrSet;
921 }
922 
923 std::tuple<unsigned, unsigned, std::string>
getSPIRVSource(Module * M)924 getSPIRVSource(Module *M) {
925   std::tuple<unsigned, unsigned, std::string> Tup;
926   if (auto N = SPIRVMDWalker(*M).getNamedMD(kSPIRVMD::Source).nextOp())
927     N.get(std::get<0>(Tup))
928      .get(std::get<1>(Tup))
929      .setQuiet(true)
930      .get(std::get<2>(Tup));
931   return Tup;
932 }
933 
mapUInt(Module * M,ConstantInt * I,std::function<unsigned (unsigned)> F)934 ConstantInt *mapUInt(Module *M, ConstantInt *I,
935     std::function<unsigned(unsigned)> F) {
936   return ConstantInt::get(I->getType(), F(I->getZExtValue()), false);
937 }
938 
mapSInt(Module * M,ConstantInt * I,std::function<int (int)> F)939 ConstantInt *mapSInt(Module *M, ConstantInt *I,
940     std::function<int(int)> F) {
941   return ConstantInt::get(I->getType(), F(I->getSExtValue()), true);
942 }
943 
944 bool
isDecoratedSPIRVFunc(const Function * F,std::string * UndecoratedName)945 isDecoratedSPIRVFunc(const Function *F, std::string *UndecoratedName) {
946   if (!F->hasName() || !F->getName().startswith(kSPIRVName::Prefix))
947     return false;
948   if (UndecoratedName)
949     *UndecoratedName = undecorateSPIRVFunction(F->getName());
950   return true;
951 }
952 
953 /// Get TypePrimitiveEnum for special OpenCL type except opencl.block.
954 SPIR::TypePrimitiveEnum
getOCLTypePrimitiveEnum(StringRef TyName)955 getOCLTypePrimitiveEnum(StringRef TyName) {
956   return StringSwitch<SPIR::TypePrimitiveEnum>(TyName)
957     .Case("opencl.image1d_t",         SPIR::PRIMITIVE_IMAGE_1D_T)
958     .Case("opencl.image1d_array_t",   SPIR::PRIMITIVE_IMAGE_1D_ARRAY_T)
959     .Case("opencl.image1d_buffer_t",  SPIR::PRIMITIVE_IMAGE_1D_BUFFER_T)
960     .Case("opencl.image2d_t",         SPIR::PRIMITIVE_IMAGE_2D_T)
961     .Case("opencl.image2d_array_t",   SPIR::PRIMITIVE_IMAGE_2D_ARRAY_T)
962     .Case("opencl.image3d_t",         SPIR::PRIMITIVE_IMAGE_3D_T)
963     .Case("opencl.image2d_msaa_t",    SPIR::PRIMITIVE_IMAGE_2D_MSAA_T)
964     .Case("opencl.image2d_array_msaa_t",        SPIR::PRIMITIVE_IMAGE_2D_ARRAY_MSAA_T)
965     .Case("opencl.image2d_msaa_depth_t",        SPIR::PRIMITIVE_IMAGE_2D_MSAA_DEPTH_T)
966     .Case("opencl.image2d_array_msaa_depth_t",  SPIR::PRIMITIVE_IMAGE_2D_ARRAY_MSAA_DEPTH_T)
967     .Case("opencl.image2d_depth_t",             SPIR::PRIMITIVE_IMAGE_2D_DEPTH_T)
968     .Case("opencl.image2d_array_depth_t",       SPIR::PRIMITIVE_IMAGE_2D_ARRAY_DEPTH_T)
969     .Case("opencl.event_t",           SPIR::PRIMITIVE_EVENT_T)
970     .Case("opencl.pipe_t",            SPIR::PRIMITIVE_PIPE_T)
971     .Case("opencl.reserve_id_t",      SPIR::PRIMITIVE_RESERVE_ID_T)
972     .Case("opencl.queue_t",           SPIR::PRIMITIVE_QUEUE_T)
973     .Case("opencl.clk_event_t",       SPIR::PRIMITIVE_CLK_EVENT_T)
974     .Case("opencl.sampler_t",         SPIR::PRIMITIVE_SAMPLER_T)
975     .Case("struct.ndrange_t",         SPIR::PRIMITIVE_NDRANGE_T)
976     .Default(                         SPIR::PRIMITIVE_NONE);
977 }
978 /// Translates LLVM type to descriptor for mangler.
979 /// \param Signed indicates integer type should be translated as signed.
980 /// \param VoidPtr indicates i8* should be translated as void*.
981 static SPIR::RefParamType
transTypeDesc(Type * Ty,const BuiltinArgTypeMangleInfo & Info)982 transTypeDesc(Type *Ty, const BuiltinArgTypeMangleInfo &Info) {
983   bool Signed = Info.IsSigned;
984   unsigned Attr = Info.Attr;
985   bool VoidPtr = Info.IsVoidPtr;
986   if (Info.IsEnum)
987     return SPIR::RefParamType(new SPIR::PrimitiveType(Info.Enum));
988   if (Info.IsSampler)
989     return SPIR::RefParamType(new SPIR::PrimitiveType(
990         SPIR::PRIMITIVE_SAMPLER_T));
991   if (Info.IsAtomic && !Ty->isPointerTy()) {
992     BuiltinArgTypeMangleInfo DTInfo = Info;
993     DTInfo.IsAtomic = false;
994     return SPIR::RefParamType(new SPIR::AtomicType(
995         transTypeDesc(Ty, DTInfo)));
996   }
997   if(auto *IntTy = dyn_cast<IntegerType>(Ty)) {
998     switch(IntTy->getBitWidth()) {
999     case 1:
1000       return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_BOOL));
1001     case 8:
1002       return SPIR::RefParamType(new SPIR::PrimitiveType(Signed?
1003           SPIR::PRIMITIVE_CHAR:SPIR::PRIMITIVE_UCHAR));
1004     case 16:
1005       return SPIR::RefParamType(new SPIR::PrimitiveType(Signed?
1006           SPIR::PRIMITIVE_SHORT:SPIR::PRIMITIVE_USHORT));
1007     case 32:
1008       return SPIR::RefParamType(new SPIR::PrimitiveType(Signed?
1009           SPIR::PRIMITIVE_INT:SPIR::PRIMITIVE_UINT));
1010     case 64:
1011       return SPIR::RefParamType(new SPIR::PrimitiveType(Signed?
1012           SPIR::PRIMITIVE_LONG:SPIR::PRIMITIVE_ULONG));
1013     default:
1014       llvm_unreachable("invliad int size");
1015     }
1016   }
1017   if (Ty->isVoidTy())
1018     return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_VOID));
1019   if (Ty->isHalfTy())
1020     return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_HALF));
1021   if (Ty->isFloatTy())
1022     return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_FLOAT));
1023   if (Ty->isDoubleTy())
1024     return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_DOUBLE));
1025   if (Ty->isVectorTy()) {
1026     return SPIR::RefParamType(new SPIR::VectorType(
1027         transTypeDesc(Ty->getVectorElementType(), Info),
1028         Ty->getVectorNumElements()));
1029   }
1030   if (Ty->isArrayTy()) {
1031     return transTypeDesc(PointerType::get(Ty->getArrayElementType(), 0), Info);
1032   }
1033   if (Ty->isStructTy()) {
1034     auto Name = Ty->getStructName();
1035     std::string Tmp;
1036 
1037     if (Name.startswith(kLLVMTypeName::StructPrefix))
1038       Name = Name.drop_front(strlen(kLLVMTypeName::StructPrefix));
1039     if (Name.startswith(kSPIRVTypeName::PrefixAndDelim)) {
1040       Name = Name.substr(sizeof(kSPIRVTypeName::PrefixAndDelim) - 1);
1041       Tmp = Name.str();
1042       auto pos = Tmp.find(kSPIRVTypeName::Delimiter); //first dot
1043       while (pos != std::string::npos) {
1044         Tmp[pos] = '_';
1045         pos = Tmp.find(kSPIRVTypeName::Delimiter, pos);
1046       }
1047       Name = Tmp = kSPIRVName::Prefix + Tmp;
1048     }
1049     // ToDo: Create a better unique name for struct without name
1050     if (Name.empty()) {
1051       std::ostringstream OS;
1052       OS << reinterpret_cast<size_t>(Ty);
1053       Name = Tmp = std::string("struct_") + OS.str();
1054     }
1055     return SPIR::RefParamType(new SPIR::UserDefinedType(Name));
1056   }
1057 
1058   if (Ty->isPointerTy()) {
1059     auto ET = Ty->getPointerElementType();
1060     SPIR::ParamType *EPT = nullptr;
1061     if (auto FT = dyn_cast<FunctionType>(ET)) {
1062       (void) FT;
1063       assert(isVoidFuncTy(FT) && "Not supported");
1064       EPT = new SPIR::BlockType;
1065     } else if (auto StructTy = dyn_cast<StructType>(ET)) {
1066       DEBUG(dbgs() << "ptr to struct: " << *Ty << '\n');
1067       auto TyName = StructTy->getStructName();
1068       if (TyName.startswith(kSPR2TypeName::ImagePrefix) ||
1069           TyName.startswith(kSPR2TypeName::Pipe)) {
1070         auto DelimPos = TyName.find_first_of(kSPR2TypeName::Delimiter,
1071             strlen(kSPR2TypeName::OCLPrefix));
1072         if (DelimPos != StringRef::npos)
1073           TyName = TyName.substr(0, DelimPos);
1074       }
1075       DEBUG(dbgs() << "  type name: " << TyName << '\n');
1076 
1077       auto Prim = getOCLTypePrimitiveEnum(TyName);
1078       if (StructTy->isOpaque()) {
1079         if (TyName == "opencl.block") {
1080           auto BlockTy = new SPIR::BlockType;
1081           // Handle block with local memory arguments according to OpenCL 2.0 spec.
1082           if(Info.IsLocalArgBlock) {
1083             SPIR::RefParamType VoidTyRef(new SPIR::PrimitiveType(SPIR::PRIMITIVE_VOID));
1084             auto VoidPtrTy = new SPIR::PointerType(VoidTyRef);
1085             VoidPtrTy->setAddressSpace(SPIR::ATTR_LOCAL);
1086             // "__local void *"
1087             BlockTy->setParam(0, SPIR::RefParamType(VoidPtrTy));
1088             // "..."
1089             BlockTy->setParam(1, SPIR::RefParamType(
1090               new SPIR::PrimitiveType(SPIR::PRIMITIVE_VAR_ARG)));
1091           }
1092           EPT = BlockTy;
1093         } else if (Prim != SPIR::PRIMITIVE_NONE) {
1094           if (Prim == SPIR::PRIMITIVE_PIPE_T) {
1095             SPIR::RefParamType OpaqueTyRef(new SPIR::PrimitiveType(Prim));
1096             auto OpaquePtrTy = new SPIR::PointerType(OpaqueTyRef);
1097             OpaquePtrTy->setAddressSpace(getOCLOpaqueTypeAddrSpace(Prim));
1098             EPT = OpaquePtrTy;
1099           }
1100           else {
1101             EPT = new SPIR::PrimitiveType(Prim);
1102           }
1103         }
1104       } else if (Prim == SPIR::PRIMITIVE_NDRANGE_T)
1105         // ndrange_t is not opaque type
1106         EPT = new SPIR::PrimitiveType(SPIR::PRIMITIVE_NDRANGE_T);
1107     }
1108     if (EPT)
1109       return SPIR::RefParamType(EPT);
1110 
1111     if (VoidPtr && ET->isIntegerTy(8))
1112       ET = Type::getVoidTy(ET->getContext());
1113     auto PT = new SPIR::PointerType(transTypeDesc(ET, Info));
1114     PT->setAddressSpace(static_cast<SPIR::TypeAttributeEnum>(
1115       Ty->getPointerAddressSpace() + (unsigned)SPIR::ATTR_ADDR_SPACE_FIRST));
1116     for (unsigned I = SPIR::ATTR_QUALIFIER_FIRST,
1117         E = SPIR::ATTR_QUALIFIER_LAST; I <= E; ++I)
1118       PT->setQualifier(static_cast<SPIR::TypeAttributeEnum>(I), I & Attr);
1119     return SPIR::RefParamType(PT);
1120   }
1121   DEBUG(dbgs() << "[transTypeDesc] " << *Ty << '\n');
1122   assert (0 && "not implemented");
1123   return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_INT));
1124 }
1125 
1126 Value *
getScalarOrArray(Value * V,unsigned Size,Instruction * Pos)1127 getScalarOrArray(Value *V, unsigned Size, Instruction *Pos) {
1128   if (!V->getType()->isPointerTy())
1129     return V;
1130   assert((isa<ConstantExpr>(V) || isa<GetElementPtrInst>(V)) &&
1131          "unexpected value type");
1132   auto GEP = cast<User>(V);
1133   assert(GEP->getNumOperands() == 3 && "must be a GEP from an array");
1134   auto P = GEP->getOperand(0);
1135   assert(P->getType()->getPointerElementType()->getArrayNumElements() == Size);
1136   auto Index0 = GEP->getOperand(1);
1137   (void) Index0;
1138   assert(dyn_cast<ConstantInt>(Index0)->getZExtValue() == 0);
1139   auto Index1 = GEP->getOperand(2);
1140   (void) Index1;
1141   assert(dyn_cast<ConstantInt>(Index1)->getZExtValue() == 0);
1142   return new LoadInst(P, "", Pos);
1143 }
1144 
1145 Constant *
getScalarOrVectorConstantInt(Type * T,uint64_t V,bool isSigned)1146 getScalarOrVectorConstantInt(Type *T, uint64_t V, bool isSigned) {
1147   if (auto IT = dyn_cast<IntegerType>(T))
1148     return ConstantInt::get(IT, V);
1149   if (auto VT = dyn_cast<VectorType>(T)) {
1150     std::vector<Constant *> EV(VT->getVectorNumElements(),
1151         getScalarOrVectorConstantInt(VT->getVectorElementType(), V, isSigned));
1152     return ConstantVector::get(EV);
1153   }
1154   llvm_unreachable("Invalid type");
1155   return nullptr;
1156 }
1157 
1158 Value *
getScalarOrArrayConstantInt(Instruction * Pos,Type * T,unsigned Len,uint64_t V,bool isSigned)1159 getScalarOrArrayConstantInt(Instruction *Pos, Type *T, unsigned Len, uint64_t V,
1160     bool isSigned) {
1161   if (auto IT = dyn_cast<IntegerType>(T)) {
1162     assert(Len == 1 && "Invalid length");
1163     return ConstantInt::get(IT, V, isSigned);
1164   }
1165   if (auto PT = dyn_cast<PointerType>(T)) {
1166     auto ET = PT->getPointerElementType();
1167     auto AT = ArrayType::get(ET, Len);
1168     std::vector<Constant *> EV(Len, ConstantInt::get(ET, V, isSigned));
1169     auto CA = ConstantArray::get(AT, EV);
1170     auto Alloca = new AllocaInst(AT, "", Pos);
1171     new StoreInst(CA, Alloca, Pos);
1172     auto Zero = ConstantInt::getNullValue(Type::getInt32Ty(T->getContext()));
1173     Value *Index[] = {Zero, Zero};
1174     auto Ret = GetElementPtrInst::CreateInBounds(Alloca, Index, "", Pos);
1175     DEBUG(dbgs() << "[getScalarOrArrayConstantInt] Alloca: " <<
1176         *Alloca << ", Return: " << *Ret << '\n');
1177     return Ret;
1178   }
1179   if (auto AT = dyn_cast<ArrayType>(T)) {
1180     auto ET = AT->getArrayElementType();
1181     assert(AT->getArrayNumElements() == Len);
1182     std::vector<Constant *> EV(Len, ConstantInt::get(ET, V, isSigned));
1183     auto Ret = ConstantArray::get(AT, EV);
1184     DEBUG(dbgs() << "[getScalarOrArrayConstantInt] Array type: " <<
1185         *AT << ", Return: " << *Ret << '\n');
1186     return Ret;
1187   }
1188   llvm_unreachable("Invalid type");
1189   return nullptr;
1190 }
1191 
1192 void
dumpUsers(Value * V,StringRef Prompt)1193 dumpUsers(Value* V, StringRef Prompt) {
1194   if (!V) return;
1195   DEBUG(dbgs() << Prompt << " Users of " << *V << " :\n");
1196   for (auto UI = V->user_begin(), UE = V->user_end(); UI != UE; ++UI)
1197     DEBUG(dbgs() << "  " << **UI << '\n');
1198 }
1199 
1200 std::string
getSPIRVTypeName(StringRef BaseName,StringRef Postfixes)1201 getSPIRVTypeName(StringRef BaseName, StringRef Postfixes) {
1202   assert(!BaseName.empty() && "Invalid SPIR-V type name");
1203   auto TN = std::string(kSPIRVTypeName::PrefixAndDelim)
1204     + BaseName.str();
1205   if (Postfixes.empty())
1206     return TN;
1207   return TN + kSPIRVTypeName::Delimiter + Postfixes.str();
1208 }
1209 
1210 bool
isSPIRVConstantName(StringRef TyName)1211 isSPIRVConstantName(StringRef TyName) {
1212   if (TyName == getSPIRVTypeName(kSPIRVTypeName::ConstantSampler) ||
1213       TyName == getSPIRVTypeName(kSPIRVTypeName::ConstantPipeStorage))
1214     return true;
1215 
1216   return false;
1217 }
1218 
1219 Type *
getSPIRVTypeByChangeBaseTypeName(Module * M,Type * T,StringRef OldName,StringRef NewName)1220 getSPIRVTypeByChangeBaseTypeName(Module *M, Type *T, StringRef OldName,
1221     StringRef NewName) {
1222   StringRef Postfixes;
1223   if (isSPIRVType(T, OldName, &Postfixes))
1224     return getOrCreateOpaquePtrType(M, getSPIRVTypeName(NewName, Postfixes));
1225   DEBUG(dbgs() << " Invalid SPIR-V type " << *T << '\n');
1226   llvm_unreachable("Invalid SPIRV-V type");
1227   return nullptr;
1228 }
1229 
1230 std::string
getSPIRVImageTypePostfixes(StringRef SampledType,SPIRVTypeImageDescriptor Desc,SPIRVAccessQualifierKind Acc)1231 getSPIRVImageTypePostfixes(StringRef SampledType,
1232     SPIRVTypeImageDescriptor Desc,
1233     SPIRVAccessQualifierKind Acc) {
1234   std::string S;
1235   raw_string_ostream OS(S);
1236   OS << SampledType << kSPIRVTypeName::PostfixDelim
1237      << Desc.Dim << kSPIRVTypeName::PostfixDelim
1238      << Desc.Depth << kSPIRVTypeName::PostfixDelim
1239      << Desc.Arrayed << kSPIRVTypeName::PostfixDelim
1240      << Desc.MS << kSPIRVTypeName::PostfixDelim
1241      << Desc.Sampled << kSPIRVTypeName::PostfixDelim
1242      << Desc.Format << kSPIRVTypeName::PostfixDelim
1243      << Acc;
1244   return OS.str();
1245 }
1246 
1247 std::string
getSPIRVImageSampledTypeName(SPIRVType * Ty)1248 getSPIRVImageSampledTypeName(SPIRVType *Ty) {
1249   switch(Ty->getOpCode()) {
1250   case OpTypeVoid:
1251     return kSPIRVImageSampledTypeName::Void;
1252   case OpTypeInt:
1253     if (Ty->getIntegerBitWidth() == 32) {
1254       if (static_cast<SPIRVTypeInt *>(Ty)->isSigned())
1255         return kSPIRVImageSampledTypeName::Int;
1256       else
1257         return kSPIRVImageSampledTypeName::UInt; }
1258     break;
1259   case OpTypeFloat:
1260     switch(Ty->getFloatBitWidth()) {
1261     case 16:
1262       return kSPIRVImageSampledTypeName::Half;
1263     case 32:
1264       return kSPIRVImageSampledTypeName::Float;
1265     default:
1266       break;
1267     }
1268     break;
1269   default:
1270     break;
1271   }
1272   llvm_unreachable("Invalid sampled type for image");
1273 }
1274 
1275 //ToDo: Find a way to represent uint sampled type in LLVM, maybe an
1276 //      opaque type.
1277 Type*
getLLVMTypeForSPIRVImageSampledTypePostfix(StringRef Postfix,LLVMContext & Ctx)1278 getLLVMTypeForSPIRVImageSampledTypePostfix(StringRef Postfix,
1279   LLVMContext &Ctx) {
1280   if (Postfix == kSPIRVImageSampledTypeName::Void)
1281     return Type::getVoidTy(Ctx);
1282   if (Postfix == kSPIRVImageSampledTypeName::Float)
1283     return Type::getFloatTy(Ctx);
1284   if (Postfix == kSPIRVImageSampledTypeName::Half)
1285     return Type::getHalfTy(Ctx);
1286   if (Postfix == kSPIRVImageSampledTypeName::Int ||
1287       Postfix == kSPIRVImageSampledTypeName::UInt)
1288     return Type::getInt32Ty(Ctx);
1289   llvm_unreachable("Invalid sampled type postfix");
1290 }
1291 
/// Map an OpenCL pipe or image struct name \p Name to the corresponding
/// SPIR-V type name built by getSPIRVTypeName(). \p Acc is an access
/// qualifier spelled as in OpenCL (e.g. "read_only", as produced by
/// getAccessQualifier) and is mapped via SPIRSPIRVAccessQualifierMap.
/// Any other type name is unreachable.
std::string
mapOCLTypeNameToSPIRV(StringRef Name, StringRef Acc) {
  std::string BaseTy;
  std::string Postfixes;
  raw_string_ostream OS(Postfixes);
  // With an access qualifier, the postfix string starts with PostfixDelim.
  if (!Acc.empty())
    OS << kSPIRVTypeName::PostfixDelim;
  if (Name.startswith(kSPR2TypeName::Pipe)) {
    // Pipe: the postfix is only the mapped access qualifier.
    BaseTy = kSPIRVTypeName::Pipe;
    OS << SPIRSPIRVAccessQualifierMap::map(Acc);
  } else if (Name.startswith(kSPR2TypeName::ImagePrefix)) {
    // Take the component after the first delimiter. This assumes Name
    // contains one (presumably "opencl.<image>_t"; confirm).
    SmallVector<StringRef, 4> SubStrs;
    const char Delims[] = {kSPR2TypeName::Delimiter, 0};
    Name.split(SubStrs, Delims);
    std::string ImageTyName = SubStrs[1].str();
    // Erase the 3 characters that start 5 from the end, i.e. the "_ro",
    // "_wo" or "_rw" before the trailing "_t" (e.g. "image2d_ro_t" ->
    // "image2d_t").
    if (hasAccessQualifiedName(Name))
      ImageTyName.erase(ImageTyName.size() - 5, 3);
    auto Desc = map<SPIRVTypeImageDescriptor>(ImageTyName);
    DEBUG(dbgs() << "[trans image type] " << SubStrs[1] << " => " <<
        "(" << (unsigned)Desc.Dim << ", " <<
               Desc.Depth << ", " <<
               Desc.Arrayed << ", " <<
               Desc.MS << ", " <<
               Desc.Sampled << ", " <<
               Desc.Format << ")\n");

    // OpenCL images are given a void sampled type here.
    BaseTy = kSPIRVTypeName::Image;
    OS << getSPIRVImageTypePostfixes(kSPIRVImageSampledTypeName::Void,
                                     Desc,
                                     SPIRSPIRVAccessQualifierMap::map(Acc));
  } else {
    DEBUG(dbgs() << "Mapping of " << Name << " is not implemented\n");
    llvm_unreachable("Not implemented");
  }
  return getSPIRVTypeName(BaseTy, OS.str());
}
1328 
1329 bool
eraseIfNoUse(Function * F)1330 eraseIfNoUse(Function *F) {
1331   bool changed = false;
1332   if (!F)
1333     return changed;
1334   if (!GlobalValue::isInternalLinkage(F->getLinkage()) &&
1335       !F->isDeclaration())
1336     return changed;
1337 
1338   dumpUsers(F, "[eraseIfNoUse] ");
1339   for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
1340     auto U = *UI++;
1341     if (auto CE = dyn_cast<ConstantExpr>(U)){
1342       if (CE->use_empty()) {
1343         CE->dropAllReferences();
1344         changed = true;
1345       }
1346     }
1347   }
1348   if (F->use_empty()) {
1349     DEBUG(dbgs() << "Erase ";
1350           F->printAsOperand(dbgs());
1351           dbgs() << '\n');
1352     F->eraseFromParent();
1353     changed = true;
1354   }
1355   return changed;
1356 }
1357 
1358 void
eraseIfNoUse(Value * V)1359 eraseIfNoUse(Value *V) {
1360   if (!V->use_empty())
1361     return;
1362   if (Constant *C = dyn_cast<Constant>(V)) {
1363     C->destroyConstant();
1364     return;
1365   }
1366   if (Instruction *I = dyn_cast<Instruction>(V)) {
1367     if (!I->mayHaveSideEffects())
1368       I->eraseFromParent();
1369   }
1370   eraseIfNoUse(dyn_cast<Function>(V));
1371 }
1372 
1373 bool
eraseUselessFunctions(Module * M)1374 eraseUselessFunctions(Module *M) {
1375   bool changed = false;
1376   for (auto I = M->begin(), E = M->end(); I != E;)
1377     changed |= eraseIfNoUse(static_cast<Function*>(I++));
1378   return changed;
1379 }
1380 
1381 std::string
mangleBuiltin(const std::string & UniqName,ArrayRef<Type * > ArgTypes,BuiltinFuncMangleInfo * BtnInfo)1382 mangleBuiltin(const std::string &UniqName,
1383     ArrayRef<Type*> ArgTypes, BuiltinFuncMangleInfo* BtnInfo) {
1384   if (!BtnInfo)
1385     return UniqName;
1386   BtnInfo->init(UniqName);
1387   std::string MangledName;
1388   DEBUG(dbgs() << "[mangle] " << UniqName << " => ");
1389   SPIR::NameMangler Mangler(SPIR::SPIR20);
1390   SPIR::FunctionDescriptor FD;
1391   FD.name = BtnInfo->getUnmangledName();
1392   bool BIVarArgNegative = BtnInfo->getVarArg() < 0;
1393 
1394   if (ArgTypes.empty()) {
1395     // Function signature cannot be ()(void, ...) so if there is an ellipsis
1396     // it must be ()(...)
1397     if(BIVarArgNegative) {
1398       FD.parameters.emplace_back(SPIR::RefParamType(new SPIR::PrimitiveType(
1399         SPIR::PRIMITIVE_VOID)));
1400     }
1401   } else {
1402     for (unsigned I = 0,
1403          E = BIVarArgNegative ? ArgTypes.size() : (unsigned)BtnInfo->getVarArg();
1404          I != E; ++I) {
1405       auto T = ArgTypes[I];
1406       FD.parameters.emplace_back(transTypeDesc(T, BtnInfo->getTypeMangleInfo(I)));
1407     }
1408   }
1409   // Ellipsis must be the last argument of any function
1410   if(!BIVarArgNegative) {
1411     assert((unsigned)BtnInfo->getVarArg() <= ArgTypes.size()
1412            && "invalid index of an ellipsis");
1413     FD.parameters.emplace_back(SPIR::RefParamType(new SPIR::PrimitiveType(
1414         SPIR::PRIMITIVE_VAR_ARG)));
1415   }
1416   Mangler.mangle(FD, MangledName);
1417   DEBUG(dbgs() << MangledName << '\n');
1418   return MangledName;
1419 }
1420 
1421 /// Check if access qualifier is encoded in the type name.
hasAccessQualifiedName(StringRef TyName)1422 bool hasAccessQualifiedName(StringRef TyName) {
1423   if (TyName.endswith("_ro_t") || TyName.endswith("_wo_t") ||
1424       TyName.endswith("_rw_t"))
1425     return true;
1426   return false;
1427 }
1428 
1429 /// Get access qualifier from the type name.
getAccessQualifier(StringRef TyName)1430 StringRef getAccessQualifier(StringRef TyName) {
1431   assert(hasAccessQualifiedName(TyName) &&
1432          "Type is not qualified with access.");
1433   auto Acc = TyName.substr(TyName.size() - 4, 2);
1434   return llvm::StringSwitch<StringRef>(Acc)
1435       .Case("ro", "read_only")
1436       .Case("wo", "write_only")
1437       .Case("rw", "read_write")
1438       .Default("");
1439 }
1440 
1441 /// Translates OpenCL image type names to SPIR-V.
getSPIRVImageTypeFromOCL(Module * M,Type * ImageTy)1442 Type *getSPIRVImageTypeFromOCL(Module *M, Type *ImageTy) {
1443   assert(isOCLImageType(ImageTy) && "Unsupported type");
1444   auto ImageTypeName = ImageTy->getPointerElementType()->getStructName();
1445   std::string Acc = kAccessQualName::ReadOnly;
1446   if (hasAccessQualifiedName(ImageTypeName))
1447     Acc = getAccessQualifier(ImageTypeName);
1448   return getOrCreateOpaquePtrType(M, mapOCLTypeNameToSPIRV(ImageTypeName, Acc));
1449 }
1450 }
1451