1 //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// This pass resolves calls to OpenCL image attribute, image resource ID and
12 /// sampler resource ID getter functions.
13 ///
14 /// Image attributes (size and format) are expected to be passed to the kernel
15 /// as kernel arguments immediately following the image argument itself,
16 /// therefore this pass adds image size and format arguments to the kernel
17 /// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
20 /// Note: this pass may invalidate pointers to functions.
21 ///
22 /// Resource IDs of read-only images, write-only images and samplers are
23 /// defined to be their index among the kernel arguments of the same
24 /// type and access qualifier.
25 //
26 //===----------------------------------------------------------------------===//
27
28 #include "AMDGPU.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/ADT/Twine.h"
32 #include "llvm/IR/Argument.h"
33 #include "llvm/IR/DerivedTypes.h"
34 #include "llvm/IR/Constants.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/Instruction.h"
37 #include "llvm/IR/Instructions.h"
38 #include "llvm/IR/Metadata.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/IR/Use.h"
42 #include "llvm/IR/User.h"
43 #include "llvm/Pass.h"
44 #include "llvm/Support/Casting.h"
45 #include "llvm/Support/ErrorHandling.h"
46 #include "llvm/Transforms/Utils/Cloning.h"
47 #include "llvm/Transforms/Utils/ValueMapper.h"
48 #include <cassert>
49 #include <cstddef>
50 #include <cstdint>
51 #include <tuple>
52
53 using namespace llvm;
54
55 static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
56 static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
57 static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
58 static StringRef GetSamplerResourceIDFunc =
59 "llvm.OpenCL.sampler.get.resource.id";
60
61 static StringRef ImageSizeArgMDType = "__llvm_image_size";
62 static StringRef ImageFormatArgMDType = "__llvm_image_format";
63
64 static StringRef KernelsMDNodeName = "opencl.kernels";
65 static StringRef KernelArgMDNodeNames[] = {
66 "kernel_arg_addr_space",
67 "kernel_arg_access_qual",
68 "kernel_arg_type",
69 "kernel_arg_base_type",
70 "kernel_arg_type_qual"};
71 static const unsigned NumKernelArgMDNodes = 5;
72
73 namespace {
74
75 using MDVector = SmallVector<Metadata *, 8>;
76 struct KernelArgMD {
77 MDVector ArgVector[NumKernelArgMDNodes];
78 };
79
80 } // end anonymous namespace
81
82 static inline bool
IsImageType(StringRef TypeString)83 IsImageType(StringRef TypeString) {
84 return TypeString == "image2d_t" || TypeString == "image3d_t";
85 }
86
87 static inline bool
IsSamplerType(StringRef TypeString)88 IsSamplerType(StringRef TypeString) {
89 return TypeString == "sampler_t";
90 }
91
92 static Function *
GetFunctionFromMDNode(MDNode * Node)93 GetFunctionFromMDNode(MDNode *Node) {
94 if (!Node)
95 return nullptr;
96
97 size_t NumOps = Node->getNumOperands();
98 if (NumOps != NumKernelArgMDNodes + 1)
99 return nullptr;
100
101 auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
102 if (!F)
103 return nullptr;
104
105 // Sanity checks.
106 size_t ExpectNumArgNodeOps = F->arg_size() + 1;
107 for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
108 MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
109 if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
110 return nullptr;
111 if (!ArgNode->getOperand(0))
112 return nullptr;
113
114 // FIXME: It should be possible to do image lowering when some metadata
115 // args missing or not in the expected order.
116 MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
117 if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
118 return nullptr;
119 }
120
121 return F;
122 }
123
124 static StringRef
AccessQualFromMD(MDNode * KernelMDNode,unsigned ArgIdx)125 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
126 MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
127 return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
128 }
129
130 static StringRef
ArgTypeFromMD(MDNode * KernelMDNode,unsigned ArgIdx)131 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
132 MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
133 return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
134 }
135
136 static MDVector
GetArgMD(MDNode * KernelMDNode,unsigned OpIdx)137 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
138 MDVector Res;
139 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
140 MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
141 Res.push_back(Node->getOperand(OpIdx));
142 }
143 return Res;
144 }
145
146 static void
PushArgMD(KernelArgMD & MD,const MDVector & V)147 PushArgMD(KernelArgMD &MD, const MDVector &V) {
148 assert(V.size() == NumKernelArgMDNodes);
149 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
150 MD.ArgVector[i].push_back(V[i]);
151 }
152 }
153
154 namespace {
155
156 class R600OpenCLImageTypeLoweringPass : public ModulePass {
157 static char ID;
158
159 LLVMContext *Context;
160 Type *Int32Type;
161 Type *ImageSizeType;
162 Type *ImageFormatType;
163 SmallVector<Instruction *, 4> InstsToErase;
164
replaceImageUses(Argument & ImageArg,uint32_t ResourceID,Argument & ImageSizeArg,Argument & ImageFormatArg)165 bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
166 Argument &ImageSizeArg,
167 Argument &ImageFormatArg) {
168 bool Modified = false;
169
170 for (auto &Use : ImageArg.uses()) {
171 auto Inst = dyn_cast<CallInst>(Use.getUser());
172 if (!Inst) {
173 continue;
174 }
175
176 Function *F = Inst->getCalledFunction();
177 if (!F)
178 continue;
179
180 Value *Replacement = nullptr;
181 StringRef Name = F->getName();
182 if (Name.startswith(GetImageResourceIDFunc)) {
183 Replacement = ConstantInt::get(Int32Type, ResourceID);
184 } else if (Name.startswith(GetImageSizeFunc)) {
185 Replacement = &ImageSizeArg;
186 } else if (Name.startswith(GetImageFormatFunc)) {
187 Replacement = &ImageFormatArg;
188 } else {
189 continue;
190 }
191
192 Inst->replaceAllUsesWith(Replacement);
193 InstsToErase.push_back(Inst);
194 Modified = true;
195 }
196
197 return Modified;
198 }
199
replaceSamplerUses(Argument & SamplerArg,uint32_t ResourceID)200 bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
201 bool Modified = false;
202
203 for (const auto &Use : SamplerArg.uses()) {
204 auto Inst = dyn_cast<CallInst>(Use.getUser());
205 if (!Inst) {
206 continue;
207 }
208
209 Function *F = Inst->getCalledFunction();
210 if (!F)
211 continue;
212
213 Value *Replacement = nullptr;
214 StringRef Name = F->getName();
215 if (Name == GetSamplerResourceIDFunc) {
216 Replacement = ConstantInt::get(Int32Type, ResourceID);
217 } else {
218 continue;
219 }
220
221 Inst->replaceAllUsesWith(Replacement);
222 InstsToErase.push_back(Inst);
223 Modified = true;
224 }
225
226 return Modified;
227 }
228
replaceImageAndSamplerUses(Function * F,MDNode * KernelMDNode)229 bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
230 uint32_t NumReadOnlyImageArgs = 0;
231 uint32_t NumWriteOnlyImageArgs = 0;
232 uint32_t NumSamplerArgs = 0;
233
234 bool Modified = false;
235 InstsToErase.clear();
236 for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
237 Argument &Arg = *ArgI;
238 StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
239
240 // Handle image types.
241 if (IsImageType(Type)) {
242 StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
243 uint32_t ResourceID;
244 if (AccessQual == "read_only") {
245 ResourceID = NumReadOnlyImageArgs++;
246 } else if (AccessQual == "write_only") {
247 ResourceID = NumWriteOnlyImageArgs++;
248 } else {
249 llvm_unreachable("Wrong image access qualifier.");
250 }
251
252 Argument &SizeArg = *(++ArgI);
253 Argument &FormatArg = *(++ArgI);
254 Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
255
256 // Handle sampler type.
257 } else if (IsSamplerType(Type)) {
258 uint32_t ResourceID = NumSamplerArgs++;
259 Modified |= replaceSamplerUses(Arg, ResourceID);
260 }
261 }
262 for (unsigned i = 0; i < InstsToErase.size(); ++i) {
263 InstsToErase[i]->eraseFromParent();
264 }
265
266 return Modified;
267 }
268
269 std::tuple<Function *, MDNode *>
addImplicitArgs(Function * F,MDNode * KernelMDNode)270 addImplicitArgs(Function *F, MDNode *KernelMDNode) {
271 bool Modified = false;
272
273 FunctionType *FT = F->getFunctionType();
274 SmallVector<Type *, 8> ArgTypes;
275
276 // Metadata operands for new MDNode.
277 KernelArgMD NewArgMDs;
278 PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
279
280 // Add implicit arguments to the signature.
281 for (unsigned i = 0; i < FT->getNumParams(); ++i) {
282 ArgTypes.push_back(FT->getParamType(i));
283 MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
284 PushArgMD(NewArgMDs, ArgMD);
285
286 if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
287 continue;
288
289 // Add size implicit argument.
290 ArgTypes.push_back(ImageSizeType);
291 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
292 PushArgMD(NewArgMDs, ArgMD);
293
294 // Add format implicit argument.
295 ArgTypes.push_back(ImageFormatType);
296 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
297 PushArgMD(NewArgMDs, ArgMD);
298
299 Modified = true;
300 }
301 if (!Modified) {
302 return std::make_tuple(nullptr, nullptr);
303 }
304
305 // Create function with new signature and clone the old body into it.
306 auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
307 auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
308 ValueToValueMapTy VMap;
309 auto NewFArgIt = NewF->arg_begin();
310 for (auto &Arg: F->args()) {
311 auto ArgName = Arg.getName();
312 NewFArgIt->setName(ArgName);
313 VMap[&Arg] = &(*NewFArgIt++);
314 if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
315 (NewFArgIt++)->setName(Twine("__size_") + ArgName);
316 (NewFArgIt++)->setName(Twine("__format_") + ArgName);
317 }
318 }
319 SmallVector<ReturnInst*, 8> Returns;
320 CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
321
322 // Build new MDNode.
323 SmallVector<Metadata *, 6> KernelMDArgs;
324 KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
325 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
326 KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
327 MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
328
329 return std::make_tuple(NewF, NewMDNode);
330 }
331
transformKernels(Module & M)332 bool transformKernels(Module &M) {
333 NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
334 if (!KernelsMDNode)
335 return false;
336
337 bool Modified = false;
338 for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
339 MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
340 Function *F = GetFunctionFromMDNode(KernelMDNode);
341 if (!F)
342 continue;
343
344 Function *NewF;
345 MDNode *NewMDNode;
346 std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
347 if (NewF) {
348 // Replace old function and metadata with new ones.
349 F->eraseFromParent();
350 M.getFunctionList().push_back(NewF);
351 M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
352 NewF->getAttributes());
353 KernelsMDNode->setOperand(i, NewMDNode);
354
355 F = NewF;
356 KernelMDNode = NewMDNode;
357 Modified = true;
358 }
359
360 Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
361 }
362
363 return Modified;
364 }
365
366 public:
R600OpenCLImageTypeLoweringPass()367 R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
368
runOnModule(Module & M)369 bool runOnModule(Module &M) override {
370 Context = &M.getContext();
371 Int32Type = Type::getInt32Ty(M.getContext());
372 ImageSizeType = ArrayType::get(Int32Type, 3);
373 ImageFormatType = ArrayType::get(Int32Type, 2);
374
375 return transformKernels(M);
376 }
377
getPassName() const378 StringRef getPassName() const override {
379 return "R600 OpenCL Image Type Pass";
380 }
381 };
382
383 } // end anonymous namespace
384
385 char R600OpenCLImageTypeLoweringPass::ID = 0;
386
createR600OpenCLImageTypeLoweringPass()387 ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
388 return new R600OpenCLImageTypeLoweringPass();
389 }
390