1 //===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares Vulkan runtime API. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef VULKAN_RUNTIME_H 14 #define VULKAN_RUNTIME_H 15 16 #include "mlir/Support/LogicalResult.h" 17 18 #include <unordered_map> 19 #include <vector> 20 #include <vulkan/vulkan.h> 21 22 using namespace mlir; 23 24 using DescriptorSetIndex = uint32_t; 25 using BindingIndex = uint32_t; 26 27 /// Struct containing information regarding to a device memory buffer. 28 struct VulkanDeviceMemoryBuffer { 29 BindingIndex bindingIndex{0}; 30 VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM}; 31 VkDescriptorBufferInfo bufferInfo{}; 32 VkBuffer hostBuffer{VK_NULL_HANDLE}; 33 VkDeviceMemory hostMemory{VK_NULL_HANDLE}; 34 VkBuffer deviceBuffer{VK_NULL_HANDLE}; 35 VkDeviceMemory deviceMemory{VK_NULL_HANDLE}; 36 uint32_t bufferSize{0}; 37 }; 38 39 /// Struct containing information regarding to a host memory buffer. 40 struct VulkanHostMemoryBuffer { 41 /// Pointer to a host memory. 42 void *ptr{nullptr}; 43 /// Size of a host memory in bytes. 44 uint32_t size{0}; 45 }; 46 47 /// Struct containing the number of local workgroups to dispatch for each 48 /// dimension. 49 struct NumWorkGroups { 50 uint32_t x{1}; 51 uint32_t y{1}; 52 uint32_t z{1}; 53 }; 54 55 /// Struct containing information regarding a descriptor set. 56 struct DescriptorSetInfo { 57 /// Index of a descriptor set in descriptor sets. 58 DescriptorSetIndex descriptorSet{0}; 59 /// Number of descriptors in a set. 60 uint32_t descriptorSize{0}; 61 /// Type of a descriptor set. 62 VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM}; 63 }; 64 65 /// VulkanHostMemoryBuffer mapped into a descriptor set and a binding. 66 using ResourceData = std::unordered_map< 67 DescriptorSetIndex, 68 std::unordered_map<BindingIndex, VulkanHostMemoryBuffer>>; 69 70 /// SPIR-V storage classes. 71 /// Note that this duplicates spirv::StorageClass but it keeps the Vulkan 72 /// runtime library detached from SPIR-V dialect, so we can avoid pick up lots 73 /// of dependencies. 74 enum class SPIRVStorageClass { 75 Uniform = 2, 76 StorageBuffer = 12, 77 }; 78 79 /// StorageClass mapped into a descriptor set and a binding. 80 using ResourceStorageClassBindingMap = 81 std::unordered_map<DescriptorSetIndex, 82 std::unordered_map<BindingIndex, SPIRVStorageClass>>; 83 84 /// Vulkan runtime. 85 /// The purpose of this class is to run SPIR-V compute shader on Vulkan 86 /// device. 87 /// Before the run, user must provide and set resource data with descriptors, 88 /// SPIR-V shader, number of work groups and entry point. After the creation of 89 /// VulkanRuntime, special methods must be called in the following 90 /// sequence: initRuntime(), run(), updateHostMemoryBuffers(), destroy(); 91 /// each method in the sequence returns success or failure depends on the Vulkan 92 /// result code. 93 class VulkanRuntime { 94 public: 95 explicit VulkanRuntime() = default; 96 VulkanRuntime(const VulkanRuntime &) = delete; 97 VulkanRuntime &operator=(const VulkanRuntime &) = delete; 98 99 /// Sets needed data for Vulkan runtime. 100 void setResourceData(const ResourceData &resData); 101 void setResourceData(const DescriptorSetIndex desIndex, 102 const BindingIndex bindIndex, 103 const VulkanHostMemoryBuffer &hostMemBuffer); 104 void setShaderModule(uint8_t *shader, uint32_t size); 105 void setNumWorkGroups(const NumWorkGroups &numberWorkGroups); 106 void setResourceStorageClassBindingMap( 107 const ResourceStorageClassBindingMap &stClassData); 108 void setEntryPoint(const char *entryPointName); 109 110 /// Runtime initialization. 111 LogicalResult initRuntime(); 112 113 /// Runs runtime. 114 LogicalResult run(); 115 116 /// Updates host memory buffers. 117 LogicalResult updateHostMemoryBuffers(); 118 119 /// Destroys all created vulkan objects and resources. 120 LogicalResult destroy(); 121 122 private: 123 //===--------------------------------------------------------------------===// 124 // Pipeline creation methods. 125 //===--------------------------------------------------------------------===// 126 127 LogicalResult createInstance(); 128 LogicalResult createDevice(); 129 LogicalResult getBestComputeQueue(); 130 LogicalResult createMemoryBuffers(); 131 LogicalResult createShaderModule(); 132 void initDescriptorSetLayoutBindingMap(); 133 LogicalResult createDescriptorSetLayout(); 134 LogicalResult createPipelineLayout(); 135 LogicalResult createComputePipeline(); 136 LogicalResult createDescriptorPool(); 137 LogicalResult allocateDescriptorSets(); 138 LogicalResult setWriteDescriptors(); 139 LogicalResult createCommandPool(); 140 LogicalResult createQueryPool(); 141 LogicalResult createComputeCommandBuffer(); 142 LogicalResult submitCommandBuffersToQueue(); 143 // Copy resources from host (staging buffer) to device buffer or from device 144 // buffer to host buffer. 145 LogicalResult copyResource(bool deviceToHost); 146 147 //===--------------------------------------------------------------------===// 148 // Helper methods. 149 //===--------------------------------------------------------------------===// 150 151 /// Maps storage class to a descriptor type. 152 LogicalResult 153 mapStorageClassToDescriptorType(SPIRVStorageClass storageClass, 154 VkDescriptorType &descriptorType); 155 156 /// Maps storage class to buffer usage flags. 157 LogicalResult 158 mapStorageClassToBufferUsageFlag(SPIRVStorageClass storageClass, 159 VkBufferUsageFlagBits &bufferUsage); 160 161 LogicalResult countDeviceMemorySize(); 162 163 //===--------------------------------------------------------------------===// 164 // Vulkan objects. 165 //===--------------------------------------------------------------------===// 166 167 VkInstance instance{VK_NULL_HANDLE}; 168 VkPhysicalDevice physicalDevice{VK_NULL_HANDLE}; 169 VkDevice device{VK_NULL_HANDLE}; 170 VkQueue queue{VK_NULL_HANDLE}; 171 172 /// Specifies VulkanDeviceMemoryBuffers divided into sets. 173 std::unordered_map<DescriptorSetIndex, std::vector<VulkanDeviceMemoryBuffer>> 174 deviceMemoryBufferMap; 175 176 /// Specifies shader module. 177 VkShaderModule shaderModule{VK_NULL_HANDLE}; 178 179 /// Specifies layout bindings. 180 std::unordered_map<DescriptorSetIndex, 181 std::vector<VkDescriptorSetLayoutBinding>> 182 descriptorSetLayoutBindingMap; 183 184 /// Specifies layouts of descriptor sets. 185 std::vector<VkDescriptorSetLayout> descriptorSetLayouts; 186 VkPipelineLayout pipelineLayout{VK_NULL_HANDLE}; 187 188 /// Specifies descriptor sets. 189 std::vector<VkDescriptorSet> descriptorSets; 190 191 /// Specifies a pool of descriptor set info, each descriptor set must have 192 /// information such as type, index and amount of bindings. 193 std::vector<DescriptorSetInfo> descriptorSetInfoPool; 194 VkDescriptorPool descriptorPool{VK_NULL_HANDLE}; 195 196 /// Timestamp query. 197 VkQueryPool queryPool{VK_NULL_HANDLE}; 198 // Number of nonoseconds for timestamp to increase 1 199 float timestampPeriod{0.f}; 200 201 /// Computation pipeline. 202 VkPipeline pipeline{VK_NULL_HANDLE}; 203 VkCommandPool commandPool{VK_NULL_HANDLE}; 204 std::vector<VkCommandBuffer> commandBuffers; 205 206 //===--------------------------------------------------------------------===// 207 // Vulkan memory context. 208 //===--------------------------------------------------------------------===// 209 210 uint32_t queueFamilyIndex{0}; 211 VkQueueFamilyProperties queueFamilyProperties{}; 212 uint32_t hostMemoryTypeIndex{VK_MAX_MEMORY_TYPES}; 213 uint32_t deviceMemoryTypeIndex{VK_MAX_MEMORY_TYPES}; 214 VkDeviceSize memorySize{0}; 215 216 //===--------------------------------------------------------------------===// 217 // Vulkan execution context. 218 //===--------------------------------------------------------------------===// 219 220 NumWorkGroups numWorkGroups; 221 const char *entryPoint{nullptr}; 222 uint8_t *binary{nullptr}; 223 uint32_t binarySize{0}; 224 225 //===--------------------------------------------------------------------===// 226 // Vulkan resource data and storage classes. 227 //===--------------------------------------------------------------------===// 228 229 ResourceData resourceData; 230 ResourceStorageClassBindingMap resourceStorageClassData; 231 }; 232 #endif 233