1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a Vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so currently
// we don't expose a separate external function in IR for each of them; instead
// we expose a few external functions to wrapper libraries that manage the
// Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "../PassDetail.h"
18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23
24 #include "llvm/ADT/SmallString.h"
25 #include "llvm/Support/FormatVariadic.h"
26
27 using namespace mlir;
28
// Callee names of the launch calls this pass looks for.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Symbol names of the runtime wrapper functions the pass emits calls to.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";

// Name of the LLVM global that holds the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";

// Attribute names read from the `vulkanLaunch` call op.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
41
42 namespace {
43
44 /// A pass to convert vulkan launch call op into a sequence of Vulkan
45 /// runtime calls in the following order:
46 ///
47 /// * initVulkan -- initializes vulkan runtime
48 /// * bindMemRef -- binds memref
49 /// * setBinaryShader -- sets the binary shader data
50 /// * setEntryPoint -- sets the entry point name
51 /// * setNumWorkGroups -- sets the number of a local workgroups
52 /// * runOnVulkan -- runs vulkan runtime
53 /// * deinitVulkan -- deinitializes vulkan runtime
54 ///
55 class VulkanLaunchFuncToVulkanCallsPass
56 : public ConvertVulkanLaunchFuncToVulkanCallsBase<
57 VulkanLaunchFuncToVulkanCallsPass> {
58 private:
initializeCachedTypes()59 void initializeCachedTypes() {
60 llvmFloatType = LLVM::LLVMType::getFloatTy(&getContext());
61 llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext());
62 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext());
63 llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext());
64 llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext());
65 }
66
getMemRefType(uint32_t rank,LLVM::LLVMType elemenType)67 LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
68 // According to the MLIR doc memref argument is converted into a
69 // pointer-to-struct argument of type:
70 // template <typename Elem, size_t Rank>
71 // struct {
72 // Elem *allocated;
73 // Elem *aligned;
74 // int64_t offset;
75 // int64_t sizes[Rank]; // omitted when rank == 0
76 // int64_t strides[Rank]; // omitted when rank == 0
77 // };
78 auto llvmPtrToElementType = elemenType.getPointerTo();
79 auto llvmArrayRankElementSizeType =
80 LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
81
82 // Create a type
83 // `!llvm<"{ `element-type`*, `element-type`*, i64,
84 // [`rank` x i64], [`rank` x i64]}">`.
85 return LLVM::LLVMType::getStructTy(
86 &getContext(),
87 {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
88 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
89 }
90
getVoidType()91 LLVM::LLVMType getVoidType() { return llvmVoidType; }
getPointerType()92 LLVM::LLVMType getPointerType() { return llvmPointerType; }
getInt32Type()93 LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
getInt64Type()94 LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
95
96 /// Creates an LLVM global for the given `name`.
97 Value createEntryPointNameConstant(StringRef name, Location loc,
98 OpBuilder &builder);
99
100 /// Declares all needed runtime functions.
101 void declareVulkanFunctions(Location loc);
102
103 /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
isVulkanLaunchCallOp(LLVM::CallOp callOp)104 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
105 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
106 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
107 }
108
109 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
110 /// op.
isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp)111 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
112 return (callOp.callee() &&
113 callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
114 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
115 }
116
117 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
118 /// runtime calls.
119 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
120
121 /// Creates call to `bindMemRef` for each memref operand.
122 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
123 Value vulkanRuntime);
124
125 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
126 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
127
128 /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
129 LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
130 uint32_t &rank, LLVM::LLVMType &type);
131
132 /// Returns a string representation from the given `type`.
stringifyType(LLVM::LLVMType type)133 StringRef stringifyType(LLVM::LLVMType type) {
134 if (type.isFloatTy())
135 return "Float";
136 if (type.isHalfTy())
137 return "Half";
138 if (type.isIntegerTy(32))
139 return "Int32";
140 if (type.isIntegerTy(16))
141 return "Int16";
142 if (type.isIntegerTy(8))
143 return "Int8";
144
145 llvm_unreachable("unsupported type");
146 }
147
148 public:
149 void runOnOperation() override;
150
151 private:
152 LLVM::LLVMType llvmFloatType;
153 LLVM::LLVMType llvmVoidType;
154 LLVM::LLVMType llvmPointerType;
155 LLVM::LLVMType llvmInt32Type;
156 LLVM::LLVMType llvmInt64Type;
157
158 // TODO: Use an associative array to support multiple vulkan launch calls.
159 std::pair<StringAttr, StringAttr> spirvAttributes;
160 /// The number of vulkan launch configuration operands, placed at the leading
161 /// positions of the operand list.
162 static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
163 };
164
165 } // anonymous namespace
166
runOnOperation()167 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
168 initializeCachedTypes();
169
170 // Collect SPIR-V attributes such as `spirv_blob` and
171 // `spirv_entry_point_name`.
172 getOperation().walk([this](LLVM::CallOp op) {
173 if (isVulkanLaunchCallOp(op))
174 collectSPIRVAttributes(op);
175 });
176
177 // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
178 getOperation().walk([this](LLVM::CallOp op) {
179 if (isCInterfaceVulkanLaunchCallOp(op))
180 translateVulkanLaunchCall(op);
181 });
182 }
183
collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp)184 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
185 LLVM::CallOp vulkanLaunchCallOp) {
186 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
187 // for the given vulkan launch call.
188 auto spirvBlobAttr =
189 vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
190 if (!spirvBlobAttr) {
191 vulkanLaunchCallOp.emitError()
192 << "missing " << kSPIRVBlobAttrName << " attribute";
193 return signalPassFailure();
194 }
195
196 auto spirvEntryPointNameAttr =
197 vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
198 if (!spirvEntryPointNameAttr) {
199 vulkanLaunchCallOp.emitError()
200 << "missing " << kSPIRVEntryPointAttrName << " attribute";
201 return signalPassFailure();
202 }
203
204 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
205 }
206
createBindMemRefCalls(LLVM::CallOp cInterfaceVulkanLaunchCallOp,Value vulkanRuntime)207 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
208 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
209 if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
210 kVulkanLaunchNumConfigOperands)
211 return;
212 OpBuilder builder(cInterfaceVulkanLaunchCallOp);
213 Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
214
215 // Create LLVM constant for the descriptor set index.
216 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
217 // pass does.
218 Value descriptorSet = builder.create<LLVM::ConstantOp>(
219 loc, getInt32Type(), builder.getI32IntegerAttr(0));
220
221 for (auto en :
222 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
223 kVulkanLaunchNumConfigOperands))) {
224 // Create LLVM constant for the descriptor binding index.
225 Value descriptorBinding = builder.create<LLVM::ConstantOp>(
226 loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
227
228 auto ptrToMemRefDescriptor = en.value();
229 uint32_t rank = 0;
230 LLVM::LLVMType type;
231 if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
232 cInterfaceVulkanLaunchCallOp.emitError()
233 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
234 return signalPassFailure();
235 }
236
237 auto symbolName =
238 llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
239 // Special case for fp16 type. Since it is not a supported type in C we use
240 // int16_t and bitcast the descriptor.
241 if (type.isHalfTy()) {
242 auto memRefTy =
243 getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext()));
244 ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
245 loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
246 }
247 // Create call to `bindMemRef`.
248 builder.create<LLVM::CallOp>(
249 loc, TypeRange{getVoidType()},
250 builder.getSymbolRefAttr(
251 StringRef(symbolName.data(), symbolName.size())),
252 ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
253 ptrToMemRefDescriptor});
254 }
255 }
256
deduceMemRefRankAndType(Value ptrToMemRefDescriptor,uint32_t & rank,LLVM::LLVMType & type)257 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
258 Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
259 auto llvmPtrDescriptorTy =
260 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
261 if (!llvmPtrDescriptorTy)
262 return failure();
263
264 auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
265 // template <typename Elem, size_t Rank>
266 // struct {
267 // Elem *allocated;
268 // Elem *aligned;
269 // int64_t offset;
270 // int64_t sizes[Rank]; // omitted when rank == 0
271 // int64_t strides[Rank]; // omitted when rank == 0
272 // };
273 if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
274 return failure();
275
276 type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy();
277 if (llvmDescriptorTy.getStructNumElements() == 3) {
278 rank = 0;
279 return success();
280 }
281 rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
282 return success();
283 }
284
declareVulkanFunctions(Location loc)285 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
286 ModuleOp module = getOperation();
287 OpBuilder builder(module.getBody()->getTerminator());
288
289 if (!module.lookupSymbol(kSetEntryPoint)) {
290 builder.create<LLVM::LLVMFuncOp>(
291 loc, kSetEntryPoint,
292 LLVM::LLVMType::getFunctionTy(getVoidType(),
293 {getPointerType(), getPointerType()},
294 /*isVarArg=*/false));
295 }
296
297 if (!module.lookupSymbol(kSetNumWorkGroups)) {
298 builder.create<LLVM::LLVMFuncOp>(
299 loc, kSetNumWorkGroups,
300 LLVM::LLVMType::getFunctionTy(
301 getVoidType(),
302 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
303 /*isVarArg=*/false));
304 }
305
306 if (!module.lookupSymbol(kSetBinaryShader)) {
307 builder.create<LLVM::LLVMFuncOp>(
308 loc, kSetBinaryShader,
309 LLVM::LLVMType::getFunctionTy(
310 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
311 /*isVarArg=*/false));
312 }
313
314 if (!module.lookupSymbol(kRunOnVulkan)) {
315 builder.create<LLVM::LLVMFuncOp>(
316 loc, kRunOnVulkan,
317 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
318 /*isVarArg=*/false));
319 }
320
321 for (unsigned i = 1; i <= 3; i++) {
322 for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(&getContext()),
323 LLVM::LLVMType::getInt32Ty(&getContext()),
324 LLVM::LLVMType::getInt16Ty(&getContext()),
325 LLVM::LLVMType::getInt8Ty(&getContext()),
326 LLVM::LLVMType::getHalfTy(&getContext())}) {
327 std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
328 std::string(stringifyType(type));
329 if (type.isHalfTy())
330 type = LLVM::LLVMType::getInt16Ty(&getContext());
331 if (!module.lookupSymbol(fnName)) {
332 auto fnType = LLVM::LLVMType::getFunctionTy(
333 getVoidType(),
334 {getPointerType(), getInt32Type(), getInt32Type(),
335 getMemRefType(i, type).getPointerTo()},
336 /*isVarArg=*/false);
337 builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
338 }
339 }
340 }
341
342 if (!module.lookupSymbol(kInitVulkan)) {
343 builder.create<LLVM::LLVMFuncOp>(
344 loc, kInitVulkan,
345 LLVM::LLVMType::getFunctionTy(getPointerType(), {},
346 /*isVarArg=*/false));
347 }
348
349 if (!module.lookupSymbol(kDeinitVulkan)) {
350 builder.create<LLVM::LLVMFuncOp>(
351 loc, kDeinitVulkan,
352 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
353 /*isVarArg=*/false));
354 }
355 }
356
createEntryPointNameConstant(StringRef name,Location loc,OpBuilder & builder)357 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
358 StringRef name, Location loc, OpBuilder &builder) {
359 SmallString<16> shaderName(name.begin(), name.end());
360 // Append `\0` to follow C style string given that LLVM::createGlobalString()
361 // won't handle this directly for us.
362 shaderName.push_back('\0');
363
364 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
365 return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
366 shaderName, LLVM::Linkage::Internal);
367 }
368
translateVulkanLaunchCall(LLVM::CallOp cInterfaceVulkanLaunchCallOp)369 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
370 LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
371 OpBuilder builder(cInterfaceVulkanLaunchCallOp);
372 Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
373 // Create call to `initVulkan`.
374 auto initVulkanCall = builder.create<LLVM::CallOp>(
375 loc, TypeRange{getPointerType()}, builder.getSymbolRefAttr(kInitVulkan),
376 ValueRange{});
377 // The result of `initVulkan` function is a pointer to Vulkan runtime, we
378 // need to pass that pointer to each Vulkan runtime call.
379 auto vulkanRuntime = initVulkanCall.getResult(0);
380
381 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
382 // that data to runtime call.
383 Value ptrToSPIRVBinary = LLVM::createGlobalString(
384 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
385 LLVM::Linkage::Internal);
386
387 // Create LLVM constant for the size of SPIR-V binary shader.
388 Value binarySize = builder.create<LLVM::ConstantOp>(
389 loc, getInt32Type(),
390 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
391
392 // Create call to `bindMemRef` for each memref operand.
393 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
394
395 // Create call to `setBinaryShader` runtime function with the given pointer to
396 // SPIR-V binary and binary size.
397 builder.create<LLVM::CallOp>(
398 loc, TypeRange{getVoidType()}, builder.getSymbolRefAttr(kSetBinaryShader),
399 ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
400 // Create LLVM global with entry point name.
401 Value entryPointName = createEntryPointNameConstant(
402 spirvAttributes.second.getValue(), loc, builder);
403 // Create call to `setEntryPoint` runtime function with the given pointer to
404 // entry point name.
405 builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
406 builder.getSymbolRefAttr(kSetEntryPoint),
407 ValueRange{vulkanRuntime, entryPointName});
408
409 // Create number of local workgroup for each dimension.
410 builder.create<LLVM::CallOp>(
411 loc, TypeRange{getVoidType()},
412 builder.getSymbolRefAttr(kSetNumWorkGroups),
413 ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
414 cInterfaceVulkanLaunchCallOp.getOperand(1),
415 cInterfaceVulkanLaunchCallOp.getOperand(2)});
416
417 // Create call to `runOnVulkan` runtime function.
418 builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
419 builder.getSymbolRefAttr(kRunOnVulkan),
420 ValueRange{vulkanRuntime});
421
422 // Create call to 'deinitVulkan' runtime function.
423 builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
424 builder.getSymbolRefAttr(kDeinitVulkan),
425 ValueRange{vulkanRuntime});
426
427 // Declare runtime functions.
428 declareVulkanFunctions(loc);
429
430 cInterfaceVulkanLaunchCallOp.erase();
431 }
432
433 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createConvertVulkanLaunchFuncToVulkanCallsPass()434 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
435 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
436 }
437