1 /*
2  * Copyright 2017, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RemoveNonkernelsPass.h"
18 
19 #include "llvm/ADT/Triple.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/raw_ostream.h"
25 
26 #include "Context.h"
27 
28 #define DEBUG_TYPE "rs2spirv-remove"
29 
30 using namespace llvm;
31 
32 namespace rs2spirv {
33 
34 namespace {
35 
HandleTargetTriple(llvm::Module & M)36 void HandleTargetTriple(llvm::Module &M) {
37   Triple TT(M.getTargetTriple());
38   auto Arch = TT.getArch();
39 
40   StringRef NewTriple;
41   switch (Arch) {
42   default:
43     llvm_unreachable("Unrecognized architecture");
44     break;
45   case Triple::arm:
46     NewTriple = "spir-unknown-unknown";
47     break;
48   case Triple::aarch64:
49     NewTriple = "spir64-unknown-unknown";
50     break;
51   case Triple::spir:
52   case Triple::spir64:
53     DEBUG(dbgs() << "!!! Already a spir triple !!!\n");
54   }
55 
56   DEBUG(dbgs() << "New triple:\t" << NewTriple << "\n");
57   M.setTargetTriple(NewTriple);
58 }
59 
60 class RemoveNonkernelsPass : public ModulePass {
61 public:
62   static char ID;
RemoveNonkernelsPass()63   explicit RemoveNonkernelsPass() : ModulePass(ID) {}
64 
getPassName() const65   const char *getPassName() const override { return "RemoveNonkernelsPass"; }
66 
runOnModule(Module & M)67   bool runOnModule(Module &M) override {
68     DEBUG(dbgs() << "RemoveNonkernelsPass\n");
69     DEBUG(M.dump());
70 
71     HandleTargetTriple(M);
72 
73     rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
74 
75     if (Ctxt.getNumForEachKernel() == 0) {
76       DEBUG(dbgs() << "RemoveNonkernelsPass detected no kernel\n");
77       // Returns false, since no modification is made to the Module.
78       return false;
79     }
80 
81     std::vector<Function *> Functions;
82     for (auto &F : M.functions()) {
83       Functions.push_back(&F);
84     }
85 
86     for (auto &F : Functions) {
87       if (F->isDeclaration())
88         continue;
89 
90       if (Ctxt.isForEachKernel(F->getName())) {
91         continue; // Skip kernels.
92       }
93 
94       F->replaceAllUsesWith(UndefValue::get((Type *)F->getType()));
95       F->eraseFromParent();
96 
97       DEBUG(dbgs() << "Removed:\t" << F->getName() << '\n');
98     }
99 
100     DEBUG(M.dump());
101     DEBUG(dbgs() << "Done removal\n");
102 
103     // Returns true, because the pass modifies the Module.
104     return true;
105   }
106 };
107 
108 } // namespace
109 
110 char RemoveNonkernelsPass::ID = 0;
111 
createRemoveNonkernelsPass()112 ModulePass *createRemoveNonkernelsPass() {
113   return new RemoveNonkernelsPass();
114 }
115 
116 } // namespace rs2spirv
117