1 /*
2  * Copyright 2016, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RSAllocationUtils.h"
18 
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/GlobalVariable.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/raw_ostream.h"
26 
#include "cxxabi.h"

#include <algorithm>
#include <cstdlib>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
31 
32 #define DEBUG_TYPE "rs2spirv-rs-allocation-utils"
33 
34 using namespace llvm;
35 
36 namespace rs2spirv {
37 
isRSAllocation(const GlobalVariable & GV)38 bool isRSAllocation(const GlobalVariable &GV) {
39   auto *PT = cast<PointerType>(GV.getType());
40   DEBUG(PT->dump());
41 
42   auto *VT = PT->getElementType();
43   DEBUG(VT->dump());
44   std::string TypeName;
45   raw_string_ostream RSO(TypeName);
46   VT->print(RSO);
47   RSO.str(); // Force flush.
48   DEBUG(dbgs() << "TypeName: " << TypeName << '\n');
49 
50   return TypeName.find("struct.rs_allocation") != std::string::npos;
51 }
52 
getRSAllocationInfo(Module & M,SmallVectorImpl<RSAllocationInfo> & Allocs)53 bool getRSAllocationInfo(Module &M, SmallVectorImpl<RSAllocationInfo> &Allocs) {
54   DEBUG(dbgs() << "getRSAllocationInfo\n");
55   for (auto &GV : M.globals()) {
56     if (GV.isDeclaration() || !isRSAllocation(GV))
57       continue;
58 
59     Allocs.push_back({'%' + GV.getName().str(), None, &GV, -1});
60   }
61 
62   return true;
63 }
64 
65 // Collect Allocation access calls into the Calls
66 // Also update Allocs with assigned ID.
67 // After calling this function, Allocs would contain the mapping from
68 // GV name to the corresponding ID.
getRSAllocAccesses(SmallVectorImpl<RSAllocationInfo> & Allocs,SmallVectorImpl<RSAllocationCallInfo> & Calls)69 bool getRSAllocAccesses(SmallVectorImpl<RSAllocationInfo> &Allocs,
70                         SmallVectorImpl<RSAllocationCallInfo> &Calls) {
71   DEBUG(dbgs() << "getRSGEATCalls\n");
72   DEBUG(dbgs() << "\n\n~~~~~~~~~~~~~~~~~~~~~\n\n");
73 
74   std::unordered_map<const Value *, const GlobalVariable *> Mapping;
75   int id_assigned = 0;
76 
77   for (auto &A : Allocs) {
78     auto *GV = A.GlobalVar;
79     std::vector<User *> WorkList(GV->user_begin(), GV->user_end());
80     size_t Idx = 0;
81 
82     while (Idx < WorkList.size()) {
83       auto *U = WorkList[Idx];
84       DEBUG(dbgs() << "Visiting ");
85       DEBUG(U->dump());
86       ++Idx;
87       auto It = Mapping.find(U);
88       if (It != Mapping.end()) {
89         if (It->second == GV) {
90           continue;
91         } else {
92           errs() << "Duplicate global mapping discovered!\n";
93           errs() << "\nGlobal: ";
94           GV->print(errs());
95           errs() << "\nExisting mapping: ";
96           It->second->print(errs());
97           errs() << "\nUser: ";
98           U->print(errs());
99           errs() << '\n';
100 
101           return false;
102         }
103       }
104 
105       Mapping[U] = GV;
106       DEBUG(dbgs() << "New mapping: ");
107       DEBUG(U->print(dbgs()));
108       DEBUG(dbgs() << " -> " << GV->getName() << '\n');
109 
110       if (auto *FCall = dyn_cast<CallInst>(U)) {
111         if (auto *F = FCall->getCalledFunction()) {
112           const auto FName = F->getName();
113           DEBUG(dbgs() << "Discovered function call to : " << FName << '\n');
114           // Treat memcpy as moves for the purpose of this analysis
115           if (FName.startswith("llvm.memcpy")) {
116             assert(FCall->getNumArgOperands() > 0);
117             Value *CopyDest = FCall->getArgOperand(0);
118             // We are interested in the users of the dest operand of
119             // memcpy here
120             Value *LocalCopy = CopyDest->stripPointerCasts();
121             User *NewU = dyn_cast<User>(LocalCopy);
122             assert(NewU);
123             WorkList.push_back(NewU);
124             continue;
125           }
126 
127           char *demangled = __cxxabiv1::__cxa_demangle(
128               FName.str().c_str(), nullptr, nullptr, nullptr);
129           if (!demangled)
130             continue;
131           const StringRef DemangledNameRef(demangled);
132           DEBUG(dbgs() << "Demangled name: " << DemangledNameRef << '\n');
133 
134           const StringRef GEAPrefix = "rsGetElementAt_";
135           const StringRef SEAPrefix = "rsSetElementAt_";
136           const StringRef DIMXPrefix = "rsAllocationGetDimX";
137           assert(GEAPrefix.size() == SEAPrefix.size());
138 
139           const bool IsGEA = DemangledNameRef.startswith(GEAPrefix);
140           const bool IsSEA = DemangledNameRef.startswith(SEAPrefix);
141           const bool IsDIMX = DemangledNameRef.startswith(DIMXPrefix);
142 
143           assert(IsGEA || IsSEA || IsDIMX);
144           if (!A.hasID()) {
145             A.assignID(id_assigned++);
146           }
147 
148           if (IsGEA || IsSEA) {
149             DEBUG(dbgs() << "Found rsAlloc function!\n");
150 
151             const auto Kind =
152                 IsGEA ? RSAllocAccessKind::GEA : RSAllocAccessKind::SEA;
153 
154             const auto RSElementTy =
155                 DemangledNameRef.drop_front(GEAPrefix.size());
156 
157             Calls.push_back({A, FCall, Kind, RSElementTy.str()});
158             continue;
159           } else if (DemangledNameRef.startswith(GEAPrefix.drop_back()) ||
160                      DemangledNameRef.startswith(SEAPrefix.drop_back())) {
161             errs() << "Untyped accesses to global rs_allocations are not "
162                       "supported.\n";
163             return false;
164           } else if (IsDIMX) {
165             DEBUG(dbgs() << "Found rsAllocationGetDimX function!\n");
166             const auto Kind = RSAllocAccessKind::DIMX;
167             Calls.push_back({A, FCall, Kind, ""});
168           }
169         }
170       }
171 
172       // TODO: Consider using set-like container to reduce computational
173       // complexity.
174       for (auto *NewU : U->users())
175         if (std::find(WorkList.begin(), WorkList.end(), NewU) == WorkList.end())
176           WorkList.push_back(NewU);
177     }
178   }
179 
180   std::unordered_map<const GlobalVariable *, std::string> GVAccessTypes;
181 
182   for (auto &Access : Calls) {
183     auto AccessElemTyIt = GVAccessTypes.find(Access.RSAlloc.GlobalVar);
184     if (AccessElemTyIt != GVAccessTypes.end() &&
185         AccessElemTyIt->second != Access.RSElementTy) {
186       errs() << "Could not infere element type for: ";
187       Access.RSAlloc.GlobalVar->print(errs());
188       errs() << '\n';
189       return false;
190     } else if (AccessElemTyIt == GVAccessTypes.end()) {
191       GVAccessTypes.emplace(Access.RSAlloc.GlobalVar, Access.RSElementTy);
192       Access.RSAlloc.RSElementType = Access.RSElementTy;
193     }
194   }
195 
196   DEBUG(dbgs() << "\n\n~~~~~~~~~~~~~~~~~~~~~\n\n");
197   return true;
198 }
199 
solidifyRSAllocAccess(Module & M,RSAllocationCallInfo CallInfo)200 bool solidifyRSAllocAccess(Module &M, RSAllocationCallInfo CallInfo) {
201   DEBUG(dbgs() << "solidifyRSAllocAccess " << CallInfo.RSAlloc.VarName << '\n');
202   auto *FCall = CallInfo.FCall;
203   auto *Fun = FCall->getCalledFunction();
204   assert(Fun);
205 
206   StringRef FName;
207   if (CallInfo.Kind == RSAllocAccessKind::DIMX)
208     FName = "rsAllocationGetDimX";
209   else
210     FName = Fun->getName();
211 
212   std::ostringstream OSS;
213   OSS << "__rsov_" << FName.str();
214   // Make up uint32_t F(uint32_t)
215   Type *UInt32Ty = IntegerType::get(M.getContext(), 32);
216   auto *NewFT = FunctionType::get(UInt32Ty, ArrayRef<Type *>(UInt32Ty), false);
217 
218   auto *NewF = Function::Create(NewFT, // Fun->getFunctionType(),
219                                 Function::ExternalLinkage, OSS.str(), &M);
220   FCall->setCalledFunction(NewF);
221   FCall->setArgOperand(0, ConstantInt::get(UInt32Ty, 0, false));
222   NewF->setAttributes(Fun->getAttributes());
223 
224   DEBUG(M.dump());
225 
226   return true;
227 }
228 
229 } // namespace rs2spirv
230