1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a Vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so
// currently we don't expose separate external functions in IR for each of
// them; instead, we expose a few external functions to wrapper libraries
// which manage the Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "../PassDetail.h"
18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 
24 #include "llvm/ADT/SmallString.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 using namespace mlir;
28 
// Names of the launch functions whose calls this pass looks for.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Names of the runtime wrapper functions that the generated calls target.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";

// Name of the LLVM global that holds the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";

// Names of the attributes read from the vulkan launch call op.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
41 
42 namespace {
43 
44 /// A pass to convert vulkan launch call op into a sequence of Vulkan
45 /// runtime calls in the following order:
46 ///
47 /// * initVulkan           -- initializes vulkan runtime
48 /// * bindMemRef           -- binds memref
49 /// * setBinaryShader      -- sets the binary shader data
50 /// * setEntryPoint        -- sets the entry point name
51 /// * setNumWorkGroups     -- sets the number of a local workgroups
52 /// * runOnVulkan          -- runs vulkan runtime
53 /// * deinitVulkan         -- deinitializes vulkan runtime
54 ///
55 class VulkanLaunchFuncToVulkanCallsPass
56     : public ConvertVulkanLaunchFuncToVulkanCallsBase<
57           VulkanLaunchFuncToVulkanCallsPass> {
58 private:
initializeCachedTypes()59   void initializeCachedTypes() {
60     llvmFloatType = LLVM::LLVMType::getFloatTy(&getContext());
61     llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext());
62     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext());
63     llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext());
64     llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext());
65   }
66 
getMemRefType(uint32_t rank,LLVM::LLVMType elemenType)67   LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
68     // According to the MLIR doc memref argument is converted into a
69     // pointer-to-struct argument of type:
70     // template <typename Elem, size_t Rank>
71     // struct {
72     //   Elem *allocated;
73     //   Elem *aligned;
74     //   int64_t offset;
75     //   int64_t sizes[Rank]; // omitted when rank == 0
76     //   int64_t strides[Rank]; // omitted when rank == 0
77     // };
78     auto llvmPtrToElementType = elemenType.getPointerTo();
79     auto llvmArrayRankElementSizeType =
80         LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
81 
82     // Create a type
83     // `!llvm<"{ `element-type`*, `element-type`*, i64,
84     // [`rank` x i64], [`rank` x i64]}">`.
85     return LLVM::LLVMType::getStructTy(
86         &getContext(),
87         {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
88          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
89   }
90 
getVoidType()91   LLVM::LLVMType getVoidType() { return llvmVoidType; }
getPointerType()92   LLVM::LLVMType getPointerType() { return llvmPointerType; }
getInt32Type()93   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
getInt64Type()94   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
95 
96   /// Creates an LLVM global for the given `name`.
97   Value createEntryPointNameConstant(StringRef name, Location loc,
98                                      OpBuilder &builder);
99 
100   /// Declares all needed runtime functions.
101   void declareVulkanFunctions(Location loc);
102 
103   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
isVulkanLaunchCallOp(LLVM::CallOp callOp)104   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
105     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
106             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
107   }
108 
109   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
110   /// op.
isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp)111   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
112     return (callOp.callee() &&
113             callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
114             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
115   }
116 
117   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
118   /// runtime calls.
119   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
120 
121   /// Creates call to `bindMemRef` for each memref operand.
122   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
123                              Value vulkanRuntime);
124 
125   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
126   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
127 
128   /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
129   LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
130                                         uint32_t &rank, LLVM::LLVMType &type);
131 
132   /// Returns a string representation from the given `type`.
stringifyType(LLVM::LLVMType type)133   StringRef stringifyType(LLVM::LLVMType type) {
134     if (type.isFloatTy())
135       return "Float";
136     if (type.isHalfTy())
137       return "Half";
138     if (type.isIntegerTy(32))
139       return "Int32";
140     if (type.isIntegerTy(16))
141       return "Int16";
142     if (type.isIntegerTy(8))
143       return "Int8";
144 
145     llvm_unreachable("unsupported type");
146   }
147 
148 public:
149   void runOnOperation() override;
150 
151 private:
152   LLVM::LLVMType llvmFloatType;
153   LLVM::LLVMType llvmVoidType;
154   LLVM::LLVMType llvmPointerType;
155   LLVM::LLVMType llvmInt32Type;
156   LLVM::LLVMType llvmInt64Type;
157 
158   // TODO: Use an associative array to support multiple vulkan launch calls.
159   std::pair<StringAttr, StringAttr> spirvAttributes;
160   /// The number of vulkan launch configuration operands, placed at the leading
161   /// positions of the operand list.
162   static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
163 };
164 
165 } // anonymous namespace
166 
runOnOperation()167 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
168   initializeCachedTypes();
169 
170   // Collect SPIR-V attributes such as `spirv_blob` and
171   // `spirv_entry_point_name`.
172   getOperation().walk([this](LLVM::CallOp op) {
173     if (isVulkanLaunchCallOp(op))
174       collectSPIRVAttributes(op);
175   });
176 
177   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
178   getOperation().walk([this](LLVM::CallOp op) {
179     if (isCInterfaceVulkanLaunchCallOp(op))
180       translateVulkanLaunchCall(op);
181   });
182 }
183 
collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp)184 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
185     LLVM::CallOp vulkanLaunchCallOp) {
186   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
187   // for the given vulkan launch call.
188   auto spirvBlobAttr =
189       vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
190   if (!spirvBlobAttr) {
191     vulkanLaunchCallOp.emitError()
192         << "missing " << kSPIRVBlobAttrName << " attribute";
193     return signalPassFailure();
194   }
195 
196   auto spirvEntryPointNameAttr =
197       vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
198   if (!spirvEntryPointNameAttr) {
199     vulkanLaunchCallOp.emitError()
200         << "missing " << kSPIRVEntryPointAttrName << " attribute";
201     return signalPassFailure();
202   }
203 
204   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
205 }
206 
createBindMemRefCalls(LLVM::CallOp cInterfaceVulkanLaunchCallOp,Value vulkanRuntime)207 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
208     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
209   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
210       kVulkanLaunchNumConfigOperands)
211     return;
212   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
213   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
214 
215   // Create LLVM constant for the descriptor set index.
216   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
217   // pass does.
218   Value descriptorSet = builder.create<LLVM::ConstantOp>(
219       loc, getInt32Type(), builder.getI32IntegerAttr(0));
220 
221   for (auto en :
222        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
223            kVulkanLaunchNumConfigOperands))) {
224     // Create LLVM constant for the descriptor binding index.
225     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
226         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
227 
228     auto ptrToMemRefDescriptor = en.value();
229     uint32_t rank = 0;
230     LLVM::LLVMType type;
231     if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
232       cInterfaceVulkanLaunchCallOp.emitError()
233           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
234       return signalPassFailure();
235     }
236 
237     auto symbolName =
238         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
239     // Special case for fp16 type. Since it is not a supported type in C we use
240     // int16_t and bitcast the descriptor.
241     if (type.isHalfTy()) {
242       auto memRefTy =
243           getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext()));
244       ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
245           loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
246     }
247     // Create call to `bindMemRef`.
248     builder.create<LLVM::CallOp>(
249         loc, TypeRange{getVoidType()},
250         builder.getSymbolRefAttr(
251             StringRef(symbolName.data(), symbolName.size())),
252         ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
253                    ptrToMemRefDescriptor});
254   }
255 }
256 
deduceMemRefRankAndType(Value ptrToMemRefDescriptor,uint32_t & rank,LLVM::LLVMType & type)257 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
258     Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
259   auto llvmPtrDescriptorTy =
260       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
261   if (!llvmPtrDescriptorTy)
262     return failure();
263 
264   auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
265   // template <typename Elem, size_t Rank>
266   // struct {
267   //   Elem *allocated;
268   //   Elem *aligned;
269   //   int64_t offset;
270   //   int64_t sizes[Rank]; // omitted when rank == 0
271   //   int64_t strides[Rank]; // omitted when rank == 0
272   // };
273   if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
274     return failure();
275 
276   type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy();
277   if (llvmDescriptorTy.getStructNumElements() == 3) {
278     rank = 0;
279     return success();
280   }
281   rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
282   return success();
283 }
284 
declareVulkanFunctions(Location loc)285 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
286   ModuleOp module = getOperation();
287   OpBuilder builder(module.getBody()->getTerminator());
288 
289   if (!module.lookupSymbol(kSetEntryPoint)) {
290     builder.create<LLVM::LLVMFuncOp>(
291         loc, kSetEntryPoint,
292         LLVM::LLVMType::getFunctionTy(getVoidType(),
293                                       {getPointerType(), getPointerType()},
294                                       /*isVarArg=*/false));
295   }
296 
297   if (!module.lookupSymbol(kSetNumWorkGroups)) {
298     builder.create<LLVM::LLVMFuncOp>(
299         loc, kSetNumWorkGroups,
300         LLVM::LLVMType::getFunctionTy(
301             getVoidType(),
302             {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
303             /*isVarArg=*/false));
304   }
305 
306   if (!module.lookupSymbol(kSetBinaryShader)) {
307     builder.create<LLVM::LLVMFuncOp>(
308         loc, kSetBinaryShader,
309         LLVM::LLVMType::getFunctionTy(
310             getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
311             /*isVarArg=*/false));
312   }
313 
314   if (!module.lookupSymbol(kRunOnVulkan)) {
315     builder.create<LLVM::LLVMFuncOp>(
316         loc, kRunOnVulkan,
317         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
318                                       /*isVarArg=*/false));
319   }
320 
321   for (unsigned i = 1; i <= 3; i++) {
322     for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(&getContext()),
323                                 LLVM::LLVMType::getInt32Ty(&getContext()),
324                                 LLVM::LLVMType::getInt16Ty(&getContext()),
325                                 LLVM::LLVMType::getInt8Ty(&getContext()),
326                                 LLVM::LLVMType::getHalfTy(&getContext())}) {
327       std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
328                            std::string(stringifyType(type));
329       if (type.isHalfTy())
330         type = LLVM::LLVMType::getInt16Ty(&getContext());
331       if (!module.lookupSymbol(fnName)) {
332         auto fnType = LLVM::LLVMType::getFunctionTy(
333             getVoidType(),
334             {getPointerType(), getInt32Type(), getInt32Type(),
335              getMemRefType(i, type).getPointerTo()},
336             /*isVarArg=*/false);
337         builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
338       }
339     }
340   }
341 
342   if (!module.lookupSymbol(kInitVulkan)) {
343     builder.create<LLVM::LLVMFuncOp>(
344         loc, kInitVulkan,
345         LLVM::LLVMType::getFunctionTy(getPointerType(), {},
346                                       /*isVarArg=*/false));
347   }
348 
349   if (!module.lookupSymbol(kDeinitVulkan)) {
350     builder.create<LLVM::LLVMFuncOp>(
351         loc, kDeinitVulkan,
352         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
353                                       /*isVarArg=*/false));
354   }
355 }
356 
createEntryPointNameConstant(StringRef name,Location loc,OpBuilder & builder)357 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
358     StringRef name, Location loc, OpBuilder &builder) {
359   SmallString<16> shaderName(name.begin(), name.end());
360   // Append `\0` to follow C style string given that LLVM::createGlobalString()
361   // won't handle this directly for us.
362   shaderName.push_back('\0');
363 
364   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
365   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
366                                   shaderName, LLVM::Linkage::Internal);
367 }
368 
translateVulkanLaunchCall(LLVM::CallOp cInterfaceVulkanLaunchCallOp)369 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
370     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
371   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
372   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
373   // Create call to `initVulkan`.
374   auto initVulkanCall = builder.create<LLVM::CallOp>(
375       loc, TypeRange{getPointerType()}, builder.getSymbolRefAttr(kInitVulkan),
376       ValueRange{});
377   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
378   // need to pass that pointer to each Vulkan runtime call.
379   auto vulkanRuntime = initVulkanCall.getResult(0);
380 
381   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
382   // that data to runtime call.
383   Value ptrToSPIRVBinary = LLVM::createGlobalString(
384       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
385       LLVM::Linkage::Internal);
386 
387   // Create LLVM constant for the size of SPIR-V binary shader.
388   Value binarySize = builder.create<LLVM::ConstantOp>(
389       loc, getInt32Type(),
390       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
391 
392   // Create call to `bindMemRef` for each memref operand.
393   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
394 
395   // Create call to `setBinaryShader` runtime function with the given pointer to
396   // SPIR-V binary and binary size.
397   builder.create<LLVM::CallOp>(
398       loc, TypeRange{getVoidType()}, builder.getSymbolRefAttr(kSetBinaryShader),
399       ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
400   // Create LLVM global with entry point name.
401   Value entryPointName = createEntryPointNameConstant(
402       spirvAttributes.second.getValue(), loc, builder);
403   // Create call to `setEntryPoint` runtime function with the given pointer to
404   // entry point name.
405   builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
406                                builder.getSymbolRefAttr(kSetEntryPoint),
407                                ValueRange{vulkanRuntime, entryPointName});
408 
409   // Create number of local workgroup for each dimension.
410   builder.create<LLVM::CallOp>(
411       loc, TypeRange{getVoidType()},
412       builder.getSymbolRefAttr(kSetNumWorkGroups),
413       ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
414                  cInterfaceVulkanLaunchCallOp.getOperand(1),
415                  cInterfaceVulkanLaunchCallOp.getOperand(2)});
416 
417   // Create call to `runOnVulkan` runtime function.
418   builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
419                                builder.getSymbolRefAttr(kRunOnVulkan),
420                                ValueRange{vulkanRuntime});
421 
422   // Create call to 'deinitVulkan' runtime function.
423   builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
424                                builder.getSymbolRefAttr(kDeinitVulkan),
425                                ValueRange{vulkanRuntime});
426 
427   // Declare runtime functions.
428   declareVulkanFunctions(loc);
429 
430   cInterfaceVulkanLaunchCallOp.erase();
431 }
432 
433 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createConvertVulkanLaunchFuncToVulkanCallsPass()434 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
435   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
436 }
437