1 /*
2  * Copyright 2010-2012, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "slang_backend.h"
18 
19 #include <string>
20 #include <vector>
21 
22 #include "clang/AST/ASTContext.h"
23 #include "clang/AST/Attr.h"
24 #include "clang/AST/Decl.h"
25 #include "clang/AST/DeclGroup.h"
26 
27 #include "clang/Basic/Diagnostic.h"
28 #include "clang/Basic/TargetInfo.h"
29 #include "clang/Basic/TargetOptions.h"
30 
31 #include "clang/CodeGen/ModuleBuilder.h"
32 
33 #include "clang/Frontend/CodeGenOptions.h"
34 #include "clang/Frontend/FrontendDiagnostic.h"
35 
36 #include "llvm/ADT/Twine.h"
37 #include "llvm/ADT/StringExtras.h"
38 
39 #include "llvm/Bitcode/ReaderWriter.h"
40 
41 #include "llvm/CodeGen/RegAllocRegistry.h"
42 #include "llvm/CodeGen/SchedulerRegistry.h"
43 
44 #include "llvm/IR/Constant.h"
45 #include "llvm/IR/Constants.h"
46 #include "llvm/IR/DataLayout.h"
47 #include "llvm/IR/DebugLoc.h"
48 #include "llvm/IR/DerivedTypes.h"
49 #include "llvm/IR/Function.h"
50 #include "llvm/IR/IRBuilder.h"
51 #include "llvm/IR/IRPrintingPasses.h"
52 #include "llvm/IR/LLVMContext.h"
53 #include "llvm/IR/Metadata.h"
54 #include "llvm/IR/Module.h"
55 
56 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
57 
58 #include "llvm/Target/TargetMachine.h"
59 #include "llvm/Target/TargetOptions.h"
60 #include "llvm/Support/TargetRegistry.h"
61 
62 #include "llvm/MC/SubtargetFeature.h"
63 
64 #include "slang_assert.h"
65 #include "slang.h"
66 #include "slang_bitcode_gen.h"
67 #include "slang_rs_context.h"
68 #include "slang_rs_export_foreach.h"
69 #include "slang_rs_export_func.h"
70 #include "slang_rs_export_reduce.h"
71 #include "slang_rs_export_type.h"
72 #include "slang_rs_export_var.h"
73 #include "slang_rs_metadata.h"
74 
75 #include "rs_cc_options.h"
76 
77 #include "strip_unknown_attributes.h"
78 
79 namespace slang {
80 
CreateFunctionPasses()81 void Backend::CreateFunctionPasses() {
82   if (!mPerFunctionPasses) {
83     mPerFunctionPasses = new llvm::legacy::FunctionPassManager(mpModule);
84 
85     llvm::PassManagerBuilder PMBuilder;
86     PMBuilder.OptLevel = mCodeGenOpts.OptimizationLevel;
87     PMBuilder.populateFunctionPassManager(*mPerFunctionPasses);
88   }
89 }
90 
CreateModulePasses()91 void Backend::CreateModulePasses() {
92   if (!mPerModulePasses) {
93     mPerModulePasses = new llvm::legacy::PassManager();
94 
95     llvm::PassManagerBuilder PMBuilder;
96     PMBuilder.OptLevel = mCodeGenOpts.OptimizationLevel;
97     PMBuilder.SizeLevel = mCodeGenOpts.OptimizeSize;
98     if (mCodeGenOpts.UnitAtATime) {
99       PMBuilder.DisableUnitAtATime = 0;
100     } else {
101       PMBuilder.DisableUnitAtATime = 1;
102     }
103 
104     if (mCodeGenOpts.UnrollLoops) {
105       PMBuilder.DisableUnrollLoops = 0;
106     } else {
107       PMBuilder.DisableUnrollLoops = 1;
108     }
109 
110     PMBuilder.populateModulePassManager(*mPerModulePasses);
111     // Add a pass to strip off unknown/unsupported attributes.
112     mPerModulePasses->add(createStripUnknownAttributesPass());
113   }
114 }
115 
CreateCodeGenPasses()116 bool Backend::CreateCodeGenPasses() {
117   if ((mOT != Slang::OT_Assembly) && (mOT != Slang::OT_Object))
118     return true;
119 
120   // Now we add passes for code emitting
121   if (mCodeGenPasses) {
122     return true;
123   } else {
124     mCodeGenPasses = new llvm::legacy::FunctionPassManager(mpModule);
125   }
126 
127   // Create the TargetMachine for generating code.
128   std::string Triple = mpModule->getTargetTriple();
129 
130   std::string Error;
131   const llvm::Target* TargetInfo =
132       llvm::TargetRegistry::lookupTarget(Triple, Error);
133   if (TargetInfo == nullptr) {
134     mDiagEngine.Report(clang::diag::err_fe_unable_to_create_target) << Error;
135     return false;
136   }
137 
138   // Target Machine Options
139   llvm::TargetOptions Options;
140 
141   // Use soft-float ABI for ARM (which is the target used by Slang during code
142   // generation).  Codegen still uses hardware FPU by default.  To use software
143   // floating point, add 'soft-float' feature to FeaturesStr below.
144   Options.FloatABIType = llvm::FloatABI::Soft;
145 
146   // BCC needs all unknown symbols resolved at compilation time. So we don't
147   // need any relocation model.
148   llvm::Reloc::Model RM = llvm::Reloc::Static;
149 
150   // This is set for the linker (specify how large of the virtual addresses we
151   // can access for all unknown symbols.)
152   llvm::CodeModel::Model CM;
153   if (mpModule->getDataLayout().getPointerSize() == 4) {
154     CM = llvm::CodeModel::Small;
155   } else {
156     // The target may have pointer size greater than 32 (e.g. x86_64
157     // architecture) may need large data address model
158     CM = llvm::CodeModel::Medium;
159   }
160 
161   // Setup feature string
162   std::string FeaturesStr;
163   if (mTargetOpts.CPU.size() || mTargetOpts.Features.size()) {
164     llvm::SubtargetFeatures Features;
165 
166     for (std::vector<std::string>::const_iterator
167              I = mTargetOpts.Features.begin(), E = mTargetOpts.Features.end();
168          I != E;
169          I++)
170       Features.AddFeature(*I);
171 
172     FeaturesStr = Features.getString();
173   }
174 
175   llvm::TargetMachine *TM =
176     TargetInfo->createTargetMachine(Triple, mTargetOpts.CPU, FeaturesStr,
177                                     Options, RM, CM);
178 
179   // Register allocation policy:
180   //  createFastRegisterAllocator: fast but bad quality
181   //  createGreedyRegisterAllocator: not so fast but good quality
182   llvm::RegisterRegAlloc::setDefault((mCodeGenOpts.OptimizationLevel == 0) ?
183                                      llvm::createFastRegisterAllocator :
184                                      llvm::createGreedyRegisterAllocator);
185 
186   llvm::CodeGenOpt::Level OptLevel = llvm::CodeGenOpt::Default;
187   if (mCodeGenOpts.OptimizationLevel == 0) {
188     OptLevel = llvm::CodeGenOpt::None;
189   } else if (mCodeGenOpts.OptimizationLevel == 3) {
190     OptLevel = llvm::CodeGenOpt::Aggressive;
191   }
192 
193   llvm::TargetMachine::CodeGenFileType CGFT =
194       llvm::TargetMachine::CGFT_AssemblyFile;
195   if (mOT == Slang::OT_Object) {
196     CGFT = llvm::TargetMachine::CGFT_ObjectFile;
197   }
198   if (TM->addPassesToEmitFile(*mCodeGenPasses, mBufferOutStream,
199                               CGFT, OptLevel)) {
200     mDiagEngine.Report(clang::diag::err_fe_unable_to_interface_with_target);
201     return false;
202   }
203 
204   return true;
205 }
206 
Backend(RSContext * Context,clang::DiagnosticsEngine * DiagEngine,const RSCCOptions & Opts,const clang::HeaderSearchOptions & HeaderSearchOpts,const clang::PreprocessorOptions & PreprocessorOpts,const clang::CodeGenOptions & CodeGenOpts,const clang::TargetOptions & TargetOpts,PragmaList * Pragmas,llvm::raw_ostream * OS,Slang::OutputType OT,clang::SourceManager & SourceMgr,bool AllowRSPrefix,bool IsFilterscript)207 Backend::Backend(RSContext *Context, clang::DiagnosticsEngine *DiagEngine,
208                  const RSCCOptions &Opts,
209                  const clang::HeaderSearchOptions &HeaderSearchOpts,
210                  const clang::PreprocessorOptions &PreprocessorOpts,
211                  const clang::CodeGenOptions &CodeGenOpts,
212                  const clang::TargetOptions &TargetOpts, PragmaList *Pragmas,
213                  llvm::raw_ostream *OS, Slang::OutputType OT,
214                  clang::SourceManager &SourceMgr, bool AllowRSPrefix,
215                  bool IsFilterscript)
216     : ASTConsumer(), mTargetOpts(TargetOpts), mpModule(nullptr), mpOS(OS),
217       mOT(OT), mGen(nullptr), mPerFunctionPasses(nullptr),
218       mPerModulePasses(nullptr), mCodeGenPasses(nullptr),
219       mBufferOutStream(*mpOS), mContext(Context),
220       mSourceMgr(SourceMgr), mASTPrint(Opts.mASTPrint), mAllowRSPrefix(AllowRSPrefix),
221       mIsFilterscript(IsFilterscript), mExportVarMetadata(nullptr),
222       mExportFuncMetadata(nullptr), mExportForEachNameMetadata(nullptr),
223       mExportForEachSignatureMetadata(nullptr),
224       mExportReduceMetadata(nullptr),
225       mExportTypeMetadata(nullptr), mRSObjectSlotsMetadata(nullptr),
226       mRefCount(mContext->getASTContext()),
227       mASTChecker(Context, Context->getTargetAPI(), IsFilterscript),
228       mForEachHandler(Context),
229       mLLVMContext(llvm::getGlobalContext()), mDiagEngine(*DiagEngine),
230       mCodeGenOpts(CodeGenOpts), mPragmas(Pragmas) {
231   mGen = CreateLLVMCodeGen(mDiagEngine, "", HeaderSearchOpts, PreprocessorOpts,
232       mCodeGenOpts, mLLVMContext);
233 }
234 
Initialize(clang::ASTContext & Ctx)235 void Backend::Initialize(clang::ASTContext &Ctx) {
236   mGen->Initialize(Ctx);
237 
238   mpModule = mGen->GetModule();
239 }
240 
HandleTranslationUnit(clang::ASTContext & Ctx)241 void Backend::HandleTranslationUnit(clang::ASTContext &Ctx) {
242   HandleTranslationUnitPre(Ctx);
243 
244   if (mASTPrint)
245     Ctx.getTranslationUnitDecl()->dump();
246 
247   mGen->HandleTranslationUnit(Ctx);
248 
249   // Here, we complete a translation unit (whole translation unit is now in LLVM
250   // IR). Now, interact with LLVM backend to generate actual machine code (asm
251   // or machine code, whatever.)
252 
253   // Silently ignore if we weren't initialized for some reason.
254   if (!mpModule)
255     return;
256 
257   llvm::Module *M = mGen->ReleaseModule();
258   if (!M) {
259     // The module has been released by IR gen on failures, do not double free.
260     mpModule = nullptr;
261     return;
262   }
263 
264   slangAssert(mpModule == M &&
265               "Unexpected module change during LLVM IR generation");
266 
267   // Insert #pragma information into metadata section of module
268   if (!mPragmas->empty()) {
269     llvm::NamedMDNode *PragmaMetadata =
270         mpModule->getOrInsertNamedMetadata(Slang::PragmaMetadataName);
271     for (PragmaList::const_iterator I = mPragmas->begin(), E = mPragmas->end();
272          I != E;
273          I++) {
274       llvm::SmallVector<llvm::Metadata*, 2> Pragma;
275       // Name goes first
276       Pragma.push_back(llvm::MDString::get(mLLVMContext, I->first));
277       // And then value
278       Pragma.push_back(llvm::MDString::get(mLLVMContext, I->second));
279 
280       // Create MDNode and insert into PragmaMetadata
281       PragmaMetadata->addOperand(
282           llvm::MDNode::get(mLLVMContext, Pragma));
283     }
284   }
285 
286   HandleTranslationUnitPost(mpModule);
287 
288   // Create passes for optimization and code emission
289 
290   // Create and run per-function passes
291   CreateFunctionPasses();
292   if (mPerFunctionPasses) {
293     mPerFunctionPasses->doInitialization();
294 
295     for (llvm::Module::iterator I = mpModule->begin(), E = mpModule->end();
296          I != E;
297          I++)
298       if (!I->isDeclaration())
299         mPerFunctionPasses->run(*I);
300 
301     mPerFunctionPasses->doFinalization();
302   }
303 
304   // Create and run module passes
305   CreateModulePasses();
306   if (mPerModulePasses)
307     mPerModulePasses->run(*mpModule);
308 
309   switch (mOT) {
310     case Slang::OT_Assembly:
311     case Slang::OT_Object: {
312       if (!CreateCodeGenPasses())
313         return;
314 
315       mCodeGenPasses->doInitialization();
316 
317       for (llvm::Module::iterator I = mpModule->begin(), E = mpModule->end();
318           I != E;
319           I++)
320         if (!I->isDeclaration())
321           mCodeGenPasses->run(*I);
322 
323       mCodeGenPasses->doFinalization();
324       break;
325     }
326     case Slang::OT_LLVMAssembly: {
327       llvm::legacy::PassManager *LLEmitPM = new llvm::legacy::PassManager();
328       LLEmitPM->add(llvm::createPrintModulePass(mBufferOutStream));
329       LLEmitPM->run(*mpModule);
330       break;
331     }
332     case Slang::OT_Bitcode: {
333       writeBitcode(mBufferOutStream, *mpModule, getTargetAPI(),
334                    mCodeGenOpts.OptimizationLevel, mCodeGenOpts.getDebugInfo());
335       break;
336     }
337     case Slang::OT_Nothing: {
338       return;
339     }
340     default: {
341       slangAssert(false && "Unknown output type");
342     }
343   }
344 }
345 
HandleTagDeclDefinition(clang::TagDecl * D)346 void Backend::HandleTagDeclDefinition(clang::TagDecl *D) {
347   mGen->HandleTagDeclDefinition(D);
348 }
349 
CompleteTentativeDefinition(clang::VarDecl * D)350 void Backend::CompleteTentativeDefinition(clang::VarDecl *D) {
351   mGen->CompleteTentativeDefinition(D);
352 }
353 
~Backend()354 Backend::~Backend() {
355   delete mpModule;
356   delete mGen;
357   delete mPerFunctionPasses;
358   delete mPerModulePasses;
359   delete mCodeGenPasses;
360 }
361 
362 // 1) Add zero initialization of local RS object types
AnnotateFunction(clang::FunctionDecl * FD)363 void Backend::AnnotateFunction(clang::FunctionDecl *FD) {
364   if (FD &&
365       FD->hasBody() &&
366       !Slang::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr)) {
367     mRefCount.Init();
368     mRefCount.SetDeclContext(FD);
369     mRefCount.Visit(FD->getBody());
370   }
371 }
372 
HandleTopLevelDecl(clang::DeclGroupRef D)373 bool Backend::HandleTopLevelDecl(clang::DeclGroupRef D) {
374   // Find and remember the types for rs_allocation and rs_script_call_t so
375   // they can be used later for translating rsForEach() calls.
376   for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
377        (mContext->getAllocationType().isNull() ||
378         mContext->getScriptCallType().isNull()) &&
379        I != E; I++) {
380     if (clang::TypeDecl* TD = llvm::dyn_cast<clang::TypeDecl>(*I)) {
381       clang::StringRef TypeName = TD->getName();
382       if (TypeName.equals("rs_allocation")) {
383         mContext->setAllocationType(TD);
384       } else if (TypeName.equals("rs_script_call_t")) {
385         mContext->setScriptCallType(TD);
386       }
387     }
388   }
389 
390   // Disallow user-defined functions with prefix "rs"
391   if (!mAllowRSPrefix) {
392     // Iterate all function declarations in the program.
393     for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
394          I != E; I++) {
395       clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
396       if (FD == nullptr)
397         continue;
398       if (!FD->getName().startswith("rs"))  // Check prefix
399         continue;
400       if (!Slang::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr))
401         mContext->ReportError(FD->getLocation(),
402                               "invalid function name prefix, "
403                               "\"rs\" is reserved: '%0'")
404             << FD->getName();
405     }
406   }
407 
408   for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
409     clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
410     if (FD) {
411       // Handle forward reference from pragma (see
412       // RSReducePragmaHandler::HandlePragma for backward reference).
413       mContext->markUsedByReducePragma(FD, RSContext::CheckNameYes);
414       if (FD->isGlobal()) {
415         // Check that we don't have any array parameters being misinterpreted as
416         // kernel pointers due to the C type system's array to pointer decay.
417         size_t numParams = FD->getNumParams();
418         for (size_t i = 0; i < numParams; i++) {
419           const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
420           clang::QualType QT = PVD->getOriginalType();
421           if (QT->isArrayType()) {
422             mContext->ReportError(
423                 PVD->getTypeSpecStartLoc(),
424                 "exported function parameters may not have array type: %0")
425                 << QT;
426           }
427         }
428         AnnotateFunction(FD);
429       }
430     }
431 
432     if (getTargetAPI() >= SLANG_FEATURE_SINGLE_SOURCE_API) {
433       if (FD && FD->hasBody() &&
434           !Slang::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr)) {
435         if (FD->hasAttr<clang::KernelAttr>()) {
436           // Log functions with attribute "kernel" by their names, and assign
437           // them slot numbers. Any other function cannot be used in a
438           // rsForEach() or rsForEachWithOptions() call, including old-style
439           // kernel functions which are defined without the "kernel" attribute.
440           mContext->addForEach(FD);
441         }
442         // Look for any kernel launch calls and translate them into using the
443         // internal API.
444         // Report a compiler error on kernel launches inside a kernel.
445         mForEachHandler.handleForEachCalls(FD, getTargetAPI());
446       }
447     }
448   }
449 
450   return mGen->HandleTopLevelDecl(D);
451 }
452 
HandleTranslationUnitPre(clang::ASTContext & C)453 void Backend::HandleTranslationUnitPre(clang::ASTContext &C) {
454   clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
455 
456   if (!mContext->processReducePragmas(this))
457     return;
458 
459   // If we have an invalid RS/FS AST, don't check further.
460   if (!mASTChecker.Validate()) {
461     return;
462   }
463 
464   if (mIsFilterscript) {
465     mContext->addPragma("rs_fp_relaxed", "");
466   }
467 
468   int version = mContext->getVersion();
469   if (version == 0) {
470     // Not setting a version is an error
471     mDiagEngine.Report(
472         mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
473         mDiagEngine.getCustomDiagID(
474             clang::DiagnosticsEngine::Error,
475             "missing pragma for version in source file"));
476   } else {
477     slangAssert(version == 1);
478   }
479 
480   if (mContext->getReflectJavaPackageName().empty()) {
481     mDiagEngine.Report(
482         mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
483         mDiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
484                                     "missing \"#pragma rs "
485                                     "java_package_name(com.foo.bar)\" "
486                                     "in source file"));
487     return;
488   }
489 
490   // Create a static global destructor if necessary (to handle RS object
491   // runtime cleanup).
492   clang::FunctionDecl *FD = mRefCount.CreateStaticGlobalDtor();
493   if (FD) {
494     HandleTopLevelDecl(clang::DeclGroupRef(FD));
495   }
496 
497   // Process any static function declarations
498   for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
499           E = TUDecl->decls_end(); I != E; I++) {
500     if ((I->getKind() >= clang::Decl::firstFunction) &&
501         (I->getKind() <= clang::Decl::lastFunction)) {
502       clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
503       if (FD && !FD->isGlobal()) {
504         AnnotateFunction(FD);
505       }
506     }
507   }
508 }
509 
510 ///////////////////////////////////////////////////////////////////////////////
dumpExportVarInfo(llvm::Module * M)511 void Backend::dumpExportVarInfo(llvm::Module *M) {
512   int slotCount = 0;
513   if (mExportVarMetadata == nullptr)
514     mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
515 
516   llvm::SmallVector<llvm::Metadata *, 2> ExportVarInfo;
517 
518   // We emit slot information (#rs_object_slots) for any reference counted
519   // RS type or pointer (which can also be bound).
520 
521   for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
522           E = mContext->export_vars_end();
523        I != E;
524        I++) {
525     const RSExportVar *EV = *I;
526     const RSExportType *ET = EV->getType();
527     bool countsAsRSObject = false;
528 
529     // Variable name
530     ExportVarInfo.push_back(
531         llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
532 
533     // Type name
534     switch (ET->getClass()) {
535       case RSExportType::ExportClassPrimitive: {
536         const RSExportPrimitiveType *PT =
537             static_cast<const RSExportPrimitiveType*>(ET);
538         ExportVarInfo.push_back(
539             llvm::MDString::get(
540               mLLVMContext, llvm::utostr_32(PT->getType())));
541         if (PT->isRSObjectType()) {
542           countsAsRSObject = true;
543         }
544         break;
545       }
546       case RSExportType::ExportClassPointer: {
547         ExportVarInfo.push_back(
548             llvm::MDString::get(
549               mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
550                 ->getPointeeType()->getName()).c_str()));
551         break;
552       }
553       case RSExportType::ExportClassMatrix: {
554         ExportVarInfo.push_back(
555             llvm::MDString::get(
556               mLLVMContext, llvm::utostr_32(
557                   /* TODO Strange value.  This pushes just a number, quite
558                    * different than the other cases.  What is this used for?
559                    * These are the metadata values that some partner drivers
560                    * want to reference (for TBAA, etc.). We may want to look
561                    * at whether these provide any reasonable value (or have
562                    * distinct enough values to actually depend on).
563                    */
564                 DataTypeRSMatrix2x2 +
565                 static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
566         break;
567       }
568       case RSExportType::ExportClassVector:
569       case RSExportType::ExportClassConstantArray:
570       case RSExportType::ExportClassRecord: {
571         ExportVarInfo.push_back(
572             llvm::MDString::get(mLLVMContext,
573               EV->getType()->getName().c_str()));
574         break;
575       }
576     }
577 
578     mExportVarMetadata->addOperand(
579         llvm::MDNode::get(mLLVMContext, ExportVarInfo));
580     ExportVarInfo.clear();
581 
582     if (mRSObjectSlotsMetadata == nullptr) {
583       mRSObjectSlotsMetadata =
584           M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
585     }
586 
587     if (countsAsRSObject) {
588       mRSObjectSlotsMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
589           llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount))));
590     }
591 
592     slotCount++;
593   }
594 }
595 
dumpExportFunctionInfo(llvm::Module * M)596 void Backend::dumpExportFunctionInfo(llvm::Module *M) {
597   if (mExportFuncMetadata == nullptr)
598     mExportFuncMetadata =
599         M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
600 
601   llvm::SmallVector<llvm::Metadata *, 1> ExportFuncInfo;
602 
603   for (RSContext::const_export_func_iterator
604           I = mContext->export_funcs_begin(),
605           E = mContext->export_funcs_end();
606        I != E;
607        I++) {
608     const RSExportFunc *EF = *I;
609 
610     // Function name
611     if (!EF->hasParam()) {
612       ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
613                                                    EF->getName().c_str()));
614     } else {
615       llvm::Function *F = M->getFunction(EF->getName());
616       llvm::Function *HelperFunction;
617       const std::string HelperFunctionName(".helper_" + EF->getName());
618 
619       slangAssert(F && "Function marked as exported disappeared in Bitcode");
620 
621       // Create helper function
622       {
623         llvm::StructType *HelperFunctionParameterTy = nullptr;
624         std::vector<bool> isStructInput;
625 
626         if (!F->getArgumentList().empty()) {
627           std::vector<llvm::Type*> HelperFunctionParameterTys;
628           for (llvm::Function::arg_iterator AI = F->arg_begin(),
629                    AE = F->arg_end(); AI != AE; AI++) {
630               if (AI->getType()->isPointerTy() && AI->getType()->getPointerElementType()->isStructTy()) {
631                   HelperFunctionParameterTys.push_back(AI->getType()->getPointerElementType());
632                   isStructInput.push_back(true);
633               } else {
634                   HelperFunctionParameterTys.push_back(AI->getType());
635                   isStructInput.push_back(false);
636               }
637           }
638           HelperFunctionParameterTy =
639               llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
640         }
641 
642         if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
643           fprintf(stderr, "Failed to export function %s: parameter type "
644                           "mismatch during creation of helper function.\n",
645                   EF->getName().c_str());
646 
647           const RSExportRecordType *Expected = EF->getParamPacketType();
648           if (Expected) {
649             fprintf(stderr, "Expected:\n");
650             Expected->getLLVMType()->dump();
651           }
652           if (HelperFunctionParameterTy) {
653             fprintf(stderr, "Got:\n");
654             HelperFunctionParameterTy->dump();
655           }
656         }
657 
658         std::vector<llvm::Type*> Params;
659         if (HelperFunctionParameterTy) {
660           llvm::PointerType *HelperFunctionParameterTyP =
661               llvm::PointerType::getUnqual(HelperFunctionParameterTy);
662           Params.push_back(HelperFunctionParameterTyP);
663         }
664 
665         llvm::FunctionType * HelperFunctionType =
666             llvm::FunctionType::get(F->getReturnType(),
667                                     Params,
668                                     /* IsVarArgs = */false);
669 
670         HelperFunction =
671             llvm::Function::Create(HelperFunctionType,
672                                    llvm::GlobalValue::ExternalLinkage,
673                                    HelperFunctionName,
674                                    M);
675 
676         HelperFunction->addFnAttr(llvm::Attribute::NoInline);
677         HelperFunction->setCallingConv(F->getCallingConv());
678 
679         // Create helper function body
680         {
681           llvm::Argument *HelperFunctionParameter =
682               &(*HelperFunction->arg_begin());
683           llvm::BasicBlock *BB =
684               llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
685           llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
686           llvm::SmallVector<llvm::Value*, 6> Params;
687           llvm::Value *Idx[2];
688 
689           Idx[0] =
690               llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
691 
692           // getelementptr and load instruction for all elements in
693           // parameter .p
694           for (size_t i = 0; i < EF->getNumParameters(); i++) {
695             // getelementptr
696             Idx[1] = llvm::ConstantInt::get(
697               llvm::Type::getInt32Ty(mLLVMContext), i);
698 
699             llvm::Value *Ptr = NULL;
700 
701             Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter, Idx);
702 
703             // Load is only required for non-struct ptrs
704             if (isStructInput[i]) {
705                 Params.push_back(Ptr);
706             } else {
707                 llvm::Value *V = IB->CreateLoad(Ptr);
708                 Params.push_back(V);
709             }
710           }
711 
712           // Call and pass the all elements as parameter to F
713           llvm::CallInst *CI = IB->CreateCall(F, Params);
714 
715           CI->setCallingConv(F->getCallingConv());
716 
717           if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext)) {
718             IB->CreateRetVoid();
719           } else {
720             IB->CreateRet(CI);
721           }
722 
723           delete IB;
724         }
725       }
726 
727       ExportFuncInfo.push_back(
728           llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
729     }
730 
731     mExportFuncMetadata->addOperand(
732         llvm::MDNode::get(mLLVMContext, ExportFuncInfo));
733     ExportFuncInfo.clear();
734   }
735 }
736 
dumpExportForEachInfo(llvm::Module * M)737 void Backend::dumpExportForEachInfo(llvm::Module *M) {
738   if (mExportForEachNameMetadata == nullptr) {
739     mExportForEachNameMetadata =
740         M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_NAME_MN);
741   }
742   if (mExportForEachSignatureMetadata == nullptr) {
743     mExportForEachSignatureMetadata =
744         M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_MN);
745   }
746 
747   llvm::SmallVector<llvm::Metadata *, 1> ExportForEachName;
748   llvm::SmallVector<llvm::Metadata *, 1> ExportForEachInfo;
749 
750   for (RSContext::const_export_foreach_iterator
751           I = mContext->export_foreach_begin(),
752           E = mContext->export_foreach_end();
753        I != E;
754        I++) {
755     const RSExportForEach *EFE = *I;
756 
757     ExportForEachName.push_back(
758         llvm::MDString::get(mLLVMContext, EFE->getName().c_str()));
759 
760     mExportForEachNameMetadata->addOperand(
761         llvm::MDNode::get(mLLVMContext, ExportForEachName));
762     ExportForEachName.clear();
763 
764     ExportForEachInfo.push_back(
765         llvm::MDString::get(mLLVMContext,
766                             llvm::utostr_32(EFE->getSignatureMetadata())));
767 
768     mExportForEachSignatureMetadata->addOperand(
769         llvm::MDNode::get(mLLVMContext, ExportForEachInfo));
770     ExportForEachInfo.clear();
771   }
772 }
773 
dumpExportReduceInfo(llvm::Module * M)774 void Backend::dumpExportReduceInfo(llvm::Module *M) {
775   if (!mExportReduceMetadata) {
776     mExportReduceMetadata =
777       M->getOrInsertNamedMetadata(RS_EXPORT_REDUCE_MN);
778   }
779 
780   llvm::SmallVector<llvm::Metadata *, 6> ExportReduceInfo;
781   // Add operand to ExportReduceInfo, padding out missing operands with
782   // nullptr.
783   auto addOperand = [&ExportReduceInfo](uint32_t Idx, llvm::Metadata *N) {
784     while (Idx > ExportReduceInfo.size())
785       ExportReduceInfo.push_back(nullptr);
786     ExportReduceInfo.push_back(N);
787   };
788   // Add string operand to ExportReduceInfo, padding out missing operands
789   // with nullptr.
790   // If string is empty, then do not add it unless Always is true.
791   auto addString = [&addOperand, this](uint32_t Idx, const std::string &S,
792                                        bool Always = true) {
793     if (Always || !S.empty())
794       addOperand(Idx, llvm::MDString::get(mLLVMContext, S));
795   };
796 
797   // Add the description of the reduction kernels to the metadata node.
798   for (auto I = mContext->export_reduce_begin(),
799             E = mContext->export_reduce_end();
800        I != E; ++I) {
801     ExportReduceInfo.clear();
802 
803     int Idx = 0;
804 
805     addString(Idx++, (*I)->getNameReduce());
806 
807     addOperand(Idx++, llvm::MDString::get(mLLVMContext, llvm::utostr_32((*I)->getAccumulatorTypeSize())));
808 
809     llvm::SmallVector<llvm::Metadata *, 2> Accumulator;
810     Accumulator.push_back(
811       llvm::MDString::get(mLLVMContext, (*I)->getNameAccumulator()));
812     Accumulator.push_back(llvm::MDString::get(
813       mLLVMContext,
814       llvm::utostr_32((*I)->getAccumulatorSignatureMetadata())));
815     addOperand(Idx++, llvm::MDTuple::get(mLLVMContext, Accumulator));
816 
817     addString(Idx++, (*I)->getNameInitializer(), false);
818     addString(Idx++, (*I)->getNameCombiner(), false);
819     addString(Idx++, (*I)->getNameOutConverter(), false);
820     addString(Idx++, (*I)->getNameHalter(), false);
821 
822     mExportReduceMetadata->addOperand(
823       llvm::MDTuple::get(mLLVMContext, ExportReduceInfo));
824   }
825 }
826 
dumpExportTypeInfo(llvm::Module * M)827 void Backend::dumpExportTypeInfo(llvm::Module *M) {
828   llvm::SmallVector<llvm::Metadata *, 1> ExportTypeInfo;
829 
830   for (RSContext::const_export_type_iterator
831           I = mContext->export_types_begin(),
832           E = mContext->export_types_end();
833        I != E;
834        I++) {
835     // First, dump type name list to export
836     const RSExportType *ET = I->getValue();
837 
838     ExportTypeInfo.clear();
839     // Type name
840     ExportTypeInfo.push_back(
841         llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
842 
843     if (ET->getClass() == RSExportType::ExportClassRecord) {
844       const RSExportRecordType *ERT =
845           static_cast<const RSExportRecordType*>(ET);
846 
847       if (mExportTypeMetadata == nullptr)
848         mExportTypeMetadata =
849             M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
850 
851       mExportTypeMetadata->addOperand(
852           llvm::MDNode::get(mLLVMContext, ExportTypeInfo));
853 
854       // Now, export struct field information to %[struct name]
855       std::string StructInfoMetadataName("%");
856       StructInfoMetadataName.append(ET->getName());
857       llvm::NamedMDNode *StructInfoMetadata =
858           M->getOrInsertNamedMetadata(StructInfoMetadataName);
859       llvm::SmallVector<llvm::Metadata *, 3> FieldInfo;
860 
861       slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
862                   "Metadata with same name was created before");
863       for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
864               FE = ERT->fields_end();
865            FI != FE;
866            FI++) {
867         const RSExportRecordType::Field *F = *FI;
868 
869         // 1. field name
870         FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
871                                                 F->getName().c_str()));
872 
873         // 2. field type name
874         FieldInfo.push_back(
875             llvm::MDString::get(mLLVMContext,
876                                 F->getType()->getName().c_str()));
877 
878         StructInfoMetadata->addOperand(
879             llvm::MDNode::get(mLLVMContext, FieldInfo));
880         FieldInfo.clear();
881       }
882     }   // ET->getClass() == RSExportType::ExportClassRecord
883   }
884 }
885 
HandleTranslationUnitPost(llvm::Module * M)886 void Backend::HandleTranslationUnitPost(llvm::Module *M) {
887 
888   if (!mContext->is64Bit()) {
889     M->setDataLayout("e-p:32:32-i64:64-v128:64:128-n32-S64");
890   }
891 
892   if (!mContext->processExports())
893     return;
894 
895   if (mContext->hasExportVar())
896     dumpExportVarInfo(M);
897 
898   if (mContext->hasExportFunc())
899     dumpExportFunctionInfo(M);
900 
901   if (mContext->hasExportForEach())
902     dumpExportForEachInfo(M);
903 
904   if (mContext->hasExportReduce())
905     dumpExportReduceInfo(M);
906 
907   if (mContext->hasExportType())
908     dumpExportTypeInfo(M);
909 }
910 
911 }  // namespace slang
912