//===- VulkanRuntime.h - MLIR Vulkan runtime --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file declares the Vulkan runtime API.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef VULKAN_RUNTIME_H
14 #define VULKAN_RUNTIME_H
15 
16 #include "mlir/Support/LogicalResult.h"
17 
18 #include <unordered_map>
19 #include <vector>
20 #include <vulkan/vulkan.h>
21 
22 using namespace mlir;
23 
/// Index type used to identify a descriptor set.
using DescriptorSetIndex = uint32_t;

/// Index type used to identify a binding within a descriptor set.
using BindingIndex = uint32_t;
26 
27 /// Struct containing information regarding to a device memory buffer.
28 struct VulkanDeviceMemoryBuffer {
29   BindingIndex bindingIndex{0};
30   VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
31   VkDescriptorBufferInfo bufferInfo{};
32   VkBuffer hostBuffer{VK_NULL_HANDLE};
33   VkDeviceMemory hostMemory{VK_NULL_HANDLE};
34   VkBuffer deviceBuffer{VK_NULL_HANDLE};
35   VkDeviceMemory deviceMemory{VK_NULL_HANDLE};
36   uint32_t bufferSize{0};
37 };
38 
/// Describes a region of host memory: where it starts and how large it is.
struct VulkanHostMemoryBuffer {
  /// Address of the host memory; null until set.
  void *ptr = nullptr;
  /// Size of the host memory, in bytes.
  uint32_t size = 0;
};
46 
/// Number of local workgroups to dispatch along each dimension. Every
/// dimension defaults to 1.
struct NumWorkGroups {
  uint32_t x = 1;
  uint32_t y = 1;
  uint32_t z = 1;
};
54 
55 /// Struct containing information regarding a descriptor set.
56 struct DescriptorSetInfo {
57   /// Index of a descriptor set in descriptor sets.
58   DescriptorSetIndex descriptorSet{0};
59   /// Number of descriptors in a set.
60   uint32_t descriptorSize{0};
61   /// Type of a descriptor set.
62   VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
63 };
64 
65 /// VulkanHostMemoryBuffer mapped into a descriptor set and a binding.
66 using ResourceData = std::unordered_map<
67     DescriptorSetIndex,
68     std::unordered_map<BindingIndex, VulkanHostMemoryBuffer>>;
69 
/// SPIR-V storage classes.
/// This intentionally duplicates spirv::StorageClass: keeping a local copy
/// lets the Vulkan runtime library stay detached from the SPIR-V dialect and
/// avoids picking up a lot of dependencies.
enum class SPIRVStorageClass {
  Uniform = 2,
  StorageBuffer = 12,
};
78 
79 /// StorageClass mapped into a descriptor set and a binding.
80 using ResourceStorageClassBindingMap =
81     std::unordered_map<DescriptorSetIndex,
82                        std::unordered_map<BindingIndex, SPIRVStorageClass>>;
83 
84 /// Vulkan runtime.
85 /// The purpose of this class is to run SPIR-V compute shader on Vulkan
86 /// device.
87 /// Before the run, user must provide and set resource data with descriptors,
88 /// SPIR-V shader, number of work groups and entry point. After the creation of
89 /// VulkanRuntime, special methods must be called in the following
90 /// sequence: initRuntime(), run(), updateHostMemoryBuffers(), destroy();
91 /// each method in the sequence returns success or failure depends on the Vulkan
92 /// result code.
93 class VulkanRuntime {
94 public:
95   explicit VulkanRuntime() = default;
96   VulkanRuntime(const VulkanRuntime &) = delete;
97   VulkanRuntime &operator=(const VulkanRuntime &) = delete;
98 
99   /// Sets needed data for Vulkan runtime.
100   void setResourceData(const ResourceData &resData);
101   void setResourceData(const DescriptorSetIndex desIndex,
102                        const BindingIndex bindIndex,
103                        const VulkanHostMemoryBuffer &hostMemBuffer);
104   void setShaderModule(uint8_t *shader, uint32_t size);
105   void setNumWorkGroups(const NumWorkGroups &numberWorkGroups);
106   void setResourceStorageClassBindingMap(
107       const ResourceStorageClassBindingMap &stClassData);
108   void setEntryPoint(const char *entryPointName);
109 
110   /// Runtime initialization.
111   LogicalResult initRuntime();
112 
113   /// Runs runtime.
114   LogicalResult run();
115 
116   /// Updates host memory buffers.
117   LogicalResult updateHostMemoryBuffers();
118 
119   /// Destroys all created vulkan objects and resources.
120   LogicalResult destroy();
121 
122 private:
123   //===--------------------------------------------------------------------===//
124   // Pipeline creation methods.
125   //===--------------------------------------------------------------------===//
126 
127   LogicalResult createInstance();
128   LogicalResult createDevice();
129   LogicalResult getBestComputeQueue();
130   LogicalResult createMemoryBuffers();
131   LogicalResult createShaderModule();
132   void initDescriptorSetLayoutBindingMap();
133   LogicalResult createDescriptorSetLayout();
134   LogicalResult createPipelineLayout();
135   LogicalResult createComputePipeline();
136   LogicalResult createDescriptorPool();
137   LogicalResult allocateDescriptorSets();
138   LogicalResult setWriteDescriptors();
139   LogicalResult createCommandPool();
140   LogicalResult createQueryPool();
141   LogicalResult createComputeCommandBuffer();
142   LogicalResult submitCommandBuffersToQueue();
143   // Copy resources from host (staging buffer) to device buffer or from device
144   // buffer to host buffer.
145   LogicalResult copyResource(bool deviceToHost);
146 
147   //===--------------------------------------------------------------------===//
148   // Helper methods.
149   //===--------------------------------------------------------------------===//
150 
151   /// Maps storage class to a descriptor type.
152   LogicalResult
153   mapStorageClassToDescriptorType(SPIRVStorageClass storageClass,
154                                   VkDescriptorType &descriptorType);
155 
156   /// Maps storage class to buffer usage flags.
157   LogicalResult
158   mapStorageClassToBufferUsageFlag(SPIRVStorageClass storageClass,
159                                    VkBufferUsageFlagBits &bufferUsage);
160 
161   LogicalResult countDeviceMemorySize();
162 
163   //===--------------------------------------------------------------------===//
164   // Vulkan objects.
165   //===--------------------------------------------------------------------===//
166 
167   VkInstance instance{VK_NULL_HANDLE};
168   VkPhysicalDevice physicalDevice{VK_NULL_HANDLE};
169   VkDevice device{VK_NULL_HANDLE};
170   VkQueue queue{VK_NULL_HANDLE};
171 
172   /// Specifies VulkanDeviceMemoryBuffers divided into sets.
173   std::unordered_map<DescriptorSetIndex, std::vector<VulkanDeviceMemoryBuffer>>
174       deviceMemoryBufferMap;
175 
176   /// Specifies shader module.
177   VkShaderModule shaderModule{VK_NULL_HANDLE};
178 
179   /// Specifies layout bindings.
180   std::unordered_map<DescriptorSetIndex,
181                      std::vector<VkDescriptorSetLayoutBinding>>
182       descriptorSetLayoutBindingMap;
183 
184   /// Specifies layouts of descriptor sets.
185   std::vector<VkDescriptorSetLayout> descriptorSetLayouts;
186   VkPipelineLayout pipelineLayout{VK_NULL_HANDLE};
187 
188   /// Specifies descriptor sets.
189   std::vector<VkDescriptorSet> descriptorSets;
190 
191   /// Specifies a pool of descriptor set info, each descriptor set must have
192   /// information such as type, index and amount of bindings.
193   std::vector<DescriptorSetInfo> descriptorSetInfoPool;
194   VkDescriptorPool descriptorPool{VK_NULL_HANDLE};
195 
196   /// Timestamp query.
197   VkQueryPool queryPool{VK_NULL_HANDLE};
198   // Number of nonoseconds for timestamp to increase 1
199   float timestampPeriod{0.f};
200 
201   /// Computation pipeline.
202   VkPipeline pipeline{VK_NULL_HANDLE};
203   VkCommandPool commandPool{VK_NULL_HANDLE};
204   std::vector<VkCommandBuffer> commandBuffers;
205 
206   //===--------------------------------------------------------------------===//
207   // Vulkan memory context.
208   //===--------------------------------------------------------------------===//
209 
210   uint32_t queueFamilyIndex{0};
211   VkQueueFamilyProperties queueFamilyProperties{};
212   uint32_t hostMemoryTypeIndex{VK_MAX_MEMORY_TYPES};
213   uint32_t deviceMemoryTypeIndex{VK_MAX_MEMORY_TYPES};
214   VkDeviceSize memorySize{0};
215 
216   //===--------------------------------------------------------------------===//
217   // Vulkan execution context.
218   //===--------------------------------------------------------------------===//
219 
220   NumWorkGroups numWorkGroups;
221   const char *entryPoint{nullptr};
222   uint8_t *binary{nullptr};
223   uint32_t binarySize{0};
224 
225   //===--------------------------------------------------------------------===//
226   // Vulkan resource data and storage classes.
227   //===--------------------------------------------------------------------===//
228 
229   ResourceData resourceData;
230   ResourceStorageClassBindingMap resourceStorageClassData;
231 };
232 #endif
233