1 //===-- AMDGPUOpenCLImageTypeLoweringPass.cpp -----------------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// This pass resolves calls to OpenCL image attribute, image resource ID and
12 /// sampler resource ID getter functions.
13 ///
14 /// Image attributes (size and format) are expected to be passed to the kernel
15 /// as kernel arguments immediately following the image argument itself,
16 /// therefore this pass adds image size and format arguments to the kernel
17 /// functions in the module. The kernel functions with image arguments are
18 /// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
20 /// Note: this pass may invalidate pointers to functions.
21 ///
22 /// Resource IDs of read-only images, write-only images and samplers are
23 /// defined to be their index among the kernel arguments of the same
24 /// type and access qualifier.
25 //===----------------------------------------------------------------------===//
26
#include "AMDGPU.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Cloning.h"
37
38 using namespace llvm;
39
40 namespace {
41
42 StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
43 StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
44 StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
45 StringRef GetSamplerResourceIDFunc = "llvm.OpenCL.sampler.get.resource.id";
46
47 StringRef ImageSizeArgMDType = "__llvm_image_size";
48 StringRef ImageFormatArgMDType = "__llvm_image_format";
49
50 StringRef KernelsMDNodeName = "opencl.kernels";
51 StringRef KernelArgMDNodeNames[] = {
52 "kernel_arg_addr_space",
53 "kernel_arg_access_qual",
54 "kernel_arg_type",
55 "kernel_arg_base_type",
56 "kernel_arg_type_qual"};
57 const unsigned NumKernelArgMDNodes = 5;
58
59 typedef SmallVector<Metadata *, 8> MDVector;
60 struct KernelArgMD {
61 MDVector ArgVector[NumKernelArgMDNodes];
62 };
63
64 } // end anonymous namespace
65
66 static inline bool
IsImageType(StringRef TypeString)67 IsImageType(StringRef TypeString) {
68 return TypeString == "image2d_t" || TypeString == "image3d_t";
69 }
70
71 static inline bool
IsSamplerType(StringRef TypeString)72 IsSamplerType(StringRef TypeString) {
73 return TypeString == "sampler_t";
74 }
75
76 static Function *
GetFunctionFromMDNode(MDNode * Node)77 GetFunctionFromMDNode(MDNode *Node) {
78 if (!Node)
79 return nullptr;
80
81 size_t NumOps = Node->getNumOperands();
82 if (NumOps != NumKernelArgMDNodes + 1)
83 return nullptr;
84
85 auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
86 if (!F)
87 return nullptr;
88
89 // Sanity checks.
90 size_t ExpectNumArgNodeOps = F->arg_size() + 1;
91 for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
92 MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
93 if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
94 return nullptr;
95 if (!ArgNode->getOperand(0))
96 return nullptr;
97
98 // FIXME: It should be possible to do image lowering when some metadata
99 // args missing or not in the expected order.
100 MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
101 if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
102 return nullptr;
103 }
104
105 return F;
106 }
107
108 static StringRef
AccessQualFromMD(MDNode * KernelMDNode,unsigned ArgIdx)109 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
110 MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
111 return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
112 }
113
114 static StringRef
ArgTypeFromMD(MDNode * KernelMDNode,unsigned ArgIdx)115 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
116 MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
117 return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
118 }
119
120 static MDVector
GetArgMD(MDNode * KernelMDNode,unsigned OpIdx)121 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
122 MDVector Res;
123 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
124 MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
125 Res.push_back(Node->getOperand(OpIdx));
126 }
127 return Res;
128 }
129
130 static void
PushArgMD(KernelArgMD & MD,const MDVector & V)131 PushArgMD(KernelArgMD &MD, const MDVector &V) {
132 assert(V.size() == NumKernelArgMDNodes);
133 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
134 MD.ArgVector[i].push_back(V[i]);
135 }
136 }
137
138 namespace {
139
140 class AMDGPUOpenCLImageTypeLoweringPass : public ModulePass {
141 static char ID;
142
143 LLVMContext *Context;
144 Type *Int32Type;
145 Type *ImageSizeType;
146 Type *ImageFormatType;
147 SmallVector<Instruction *, 4> InstsToErase;
148
replaceImageUses(Argument & ImageArg,uint32_t ResourceID,Argument & ImageSizeArg,Argument & ImageFormatArg)149 bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
150 Argument &ImageSizeArg,
151 Argument &ImageFormatArg) {
152 bool Modified = false;
153
154 for (auto &Use : ImageArg.uses()) {
155 auto Inst = dyn_cast<CallInst>(Use.getUser());
156 if (!Inst) {
157 continue;
158 }
159
160 Function *F = Inst->getCalledFunction();
161 if (!F)
162 continue;
163
164 Value *Replacement = nullptr;
165 StringRef Name = F->getName();
166 if (Name.startswith(GetImageResourceIDFunc)) {
167 Replacement = ConstantInt::get(Int32Type, ResourceID);
168 } else if (Name.startswith(GetImageSizeFunc)) {
169 Replacement = &ImageSizeArg;
170 } else if (Name.startswith(GetImageFormatFunc)) {
171 Replacement = &ImageFormatArg;
172 } else {
173 continue;
174 }
175
176 Inst->replaceAllUsesWith(Replacement);
177 InstsToErase.push_back(Inst);
178 Modified = true;
179 }
180
181 return Modified;
182 }
183
replaceSamplerUses(Argument & SamplerArg,uint32_t ResourceID)184 bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
185 bool Modified = false;
186
187 for (const auto &Use : SamplerArg.uses()) {
188 auto Inst = dyn_cast<CallInst>(Use.getUser());
189 if (!Inst) {
190 continue;
191 }
192
193 Function *F = Inst->getCalledFunction();
194 if (!F)
195 continue;
196
197 Value *Replacement = nullptr;
198 StringRef Name = F->getName();
199 if (Name == GetSamplerResourceIDFunc) {
200 Replacement = ConstantInt::get(Int32Type, ResourceID);
201 } else {
202 continue;
203 }
204
205 Inst->replaceAllUsesWith(Replacement);
206 InstsToErase.push_back(Inst);
207 Modified = true;
208 }
209
210 return Modified;
211 }
212
replaceImageAndSamplerUses(Function * F,MDNode * KernelMDNode)213 bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
214 uint32_t NumReadOnlyImageArgs = 0;
215 uint32_t NumWriteOnlyImageArgs = 0;
216 uint32_t NumSamplerArgs = 0;
217
218 bool Modified = false;
219 InstsToErase.clear();
220 for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
221 Argument &Arg = *ArgI;
222 StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
223
224 // Handle image types.
225 if (IsImageType(Type)) {
226 StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
227 uint32_t ResourceID;
228 if (AccessQual == "read_only") {
229 ResourceID = NumReadOnlyImageArgs++;
230 } else if (AccessQual == "write_only") {
231 ResourceID = NumWriteOnlyImageArgs++;
232 } else {
233 llvm_unreachable("Wrong image access qualifier.");
234 }
235
236 Argument &SizeArg = *(++ArgI);
237 Argument &FormatArg = *(++ArgI);
238 Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
239
240 // Handle sampler type.
241 } else if (IsSamplerType(Type)) {
242 uint32_t ResourceID = NumSamplerArgs++;
243 Modified |= replaceSamplerUses(Arg, ResourceID);
244 }
245 }
246 for (unsigned i = 0; i < InstsToErase.size(); ++i) {
247 InstsToErase[i]->eraseFromParent();
248 }
249
250 return Modified;
251 }
252
253 std::tuple<Function *, MDNode *>
addImplicitArgs(Function * F,MDNode * KernelMDNode)254 addImplicitArgs(Function *F, MDNode *KernelMDNode) {
255 bool Modified = false;
256
257 FunctionType *FT = F->getFunctionType();
258 SmallVector<Type *, 8> ArgTypes;
259
260 // Metadata operands for new MDNode.
261 KernelArgMD NewArgMDs;
262 PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
263
264 // Add implicit arguments to the signature.
265 for (unsigned i = 0; i < FT->getNumParams(); ++i) {
266 ArgTypes.push_back(FT->getParamType(i));
267 MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
268 PushArgMD(NewArgMDs, ArgMD);
269
270 if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
271 continue;
272
273 // Add size implicit argument.
274 ArgTypes.push_back(ImageSizeType);
275 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
276 PushArgMD(NewArgMDs, ArgMD);
277
278 // Add format implicit argument.
279 ArgTypes.push_back(ImageFormatType);
280 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
281 PushArgMD(NewArgMDs, ArgMD);
282
283 Modified = true;
284 }
285 if (!Modified) {
286 return std::make_tuple(nullptr, nullptr);
287 }
288
289 // Create function with new signature and clone the old body into it.
290 auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
291 auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
292 ValueToValueMapTy VMap;
293 auto NewFArgIt = NewF->arg_begin();
294 for (auto &Arg: F->args()) {
295 auto ArgName = Arg.getName();
296 NewFArgIt->setName(ArgName);
297 VMap[&Arg] = &(*NewFArgIt++);
298 if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
299 (NewFArgIt++)->setName(Twine("__size_") + ArgName);
300 (NewFArgIt++)->setName(Twine("__format_") + ArgName);
301 }
302 }
303 SmallVector<ReturnInst*, 8> Returns;
304 CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
305
306 // Build new MDNode.
307 SmallVector<llvm::Metadata *, 6> KernelMDArgs;
308 KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
309 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
310 KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
311 MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
312
313 return std::make_tuple(NewF, NewMDNode);
314 }
315
transformKernels(Module & M)316 bool transformKernels(Module &M) {
317 NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
318 if (!KernelsMDNode)
319 return false;
320
321 bool Modified = false;
322 for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
323 MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
324 Function *F = GetFunctionFromMDNode(KernelMDNode);
325 if (!F)
326 continue;
327
328 Function *NewF;
329 MDNode *NewMDNode;
330 std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
331 if (NewF) {
332 // Replace old function and metadata with new ones.
333 F->eraseFromParent();
334 M.getFunctionList().push_back(NewF);
335 M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
336 NewF->getAttributes());
337 KernelsMDNode->setOperand(i, NewMDNode);
338
339 F = NewF;
340 KernelMDNode = NewMDNode;
341 Modified = true;
342 }
343
344 Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
345 }
346
347 return Modified;
348 }
349
350 public:
AMDGPUOpenCLImageTypeLoweringPass()351 AMDGPUOpenCLImageTypeLoweringPass() : ModulePass(ID) {}
352
runOnModule(Module & M)353 bool runOnModule(Module &M) override {
354 Context = &M.getContext();
355 Int32Type = Type::getInt32Ty(M.getContext());
356 ImageSizeType = ArrayType::get(Int32Type, 3);
357 ImageFormatType = ArrayType::get(Int32Type, 2);
358
359 return transformKernels(M);
360 }
361
getPassName() const362 const char *getPassName() const override {
363 return "AMDGPU OpenCL Image Type Pass";
364 }
365 };
366
367 char AMDGPUOpenCLImageTypeLoweringPass::ID = 0;
368
369 } // end anonymous namespace
370
createAMDGPUOpenCLImageTypeLoweringPass()371 ModulePass *llvm::createAMDGPUOpenCLImageTypeLoweringPass() {
372 return new AMDGPUOpenCLImageTypeLoweringPass();
373 }
374