1 //===- vulkan-runtime-wrappers.cpp - MLIR Vulkan runner wrapper library ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implements C runtime wrappers around the VulkanRuntime.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <iostream>
14 #include <mutex>
15 #include <numeric>
16 
17 #include "VulkanRuntime.h"
18 
19 // Explicitly export entry points to the vulkan-runtime-wrapper.
20 #define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
21 
22 namespace {
23 
24 class VulkanRuntimeManager {
25 public:
26   VulkanRuntimeManager() = default;
27   VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
28   VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete;
29   ~VulkanRuntimeManager() = default;
30 
setResourceData(DescriptorSetIndex setIndex,BindingIndex bindIndex,const VulkanHostMemoryBuffer & memBuffer)31   void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
32                        const VulkanHostMemoryBuffer &memBuffer) {
33     std::lock_guard<std::mutex> lock(mutex);
34     vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
35   }
36 
setEntryPoint(const char * entryPoint)37   void setEntryPoint(const char *entryPoint) {
38     std::lock_guard<std::mutex> lock(mutex);
39     vulkanRuntime.setEntryPoint(entryPoint);
40   }
41 
setNumWorkGroups(NumWorkGroups numWorkGroups)42   void setNumWorkGroups(NumWorkGroups numWorkGroups) {
43     std::lock_guard<std::mutex> lock(mutex);
44     vulkanRuntime.setNumWorkGroups(numWorkGroups);
45   }
46 
setShaderModule(uint8_t * shader,uint32_t size)47   void setShaderModule(uint8_t *shader, uint32_t size) {
48     std::lock_guard<std::mutex> lock(mutex);
49     vulkanRuntime.setShaderModule(shader, size);
50   }
51 
runOnVulkan()52   void runOnVulkan() {
53     std::lock_guard<std::mutex> lock(mutex);
54     if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
55         failed(vulkanRuntime.updateHostMemoryBuffers()) ||
56         failed(vulkanRuntime.destroy())) {
57       std::cerr << "runOnVulkan failed";
58     }
59   }
60 
61 private:
62   VulkanRuntime vulkanRuntime;
63   std::mutex mutex;
64 };
65 
66 } // namespace
67 
/// Plain-data description of a ranked memref with `Rank` dimensions and
/// elements of type `ElemT`.
///
/// NOTE(review): the field order presumably mirrors the descriptor layout that
/// callers pass in by pointer -- confirm before reordering or resizing fields.
template <typename ElemT, int Rank> struct MemRefDescriptor {
  /// Data pointer; the code in this file treats it as the start of the
  /// element storage.
  ElemT *allocated;
  /// Second data pointer carried by the descriptor.
  ElemT *aligned;
  /// Offset carried by the descriptor.
  int64_t offset;
  /// Extent of each of the `Rank` dimensions.
  int64_t sizes[Rank];
  /// Stride of each of the `Rank` dimensions.
  int64_t strides[Rank];
};
75 
76 template <typename T, uint32_t S>
bindMemRef(void * vkRuntimeManager,DescriptorSetIndex setIndex,BindingIndex bindIndex,MemRefDescriptor<T,S> * ptr)77 void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex,
78                 BindingIndex bindIndex, MemRefDescriptor<T, S> *ptr) {
79   uint32_t size = sizeof(T);
80   for (unsigned i = 0; i < S; i++)
81     size *= ptr->sizes[i];
82   VulkanHostMemoryBuffer memBuffer{ptr->allocated, size};
83   reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
84       ->setResourceData(setIndex, bindIndex, memBuffer);
85 }
86 
87 extern "C" {
88 /// Initializes `VulkanRuntimeManager` and returns a pointer to it.
initVulkan()89 VULKAN_WRAPPER_SYMBOL_EXPORT void *initVulkan() {
90   return new VulkanRuntimeManager();
91 }
92 
93 /// Deinitializes `VulkanRuntimeManager` by the given pointer.
deinitVulkan(void * vkRuntimeManager)94 VULKAN_WRAPPER_SYMBOL_EXPORT void deinitVulkan(void *vkRuntimeManager) {
95   delete reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager);
96 }
97 
runOnVulkan(void * vkRuntimeManager)98 VULKAN_WRAPPER_SYMBOL_EXPORT void runOnVulkan(void *vkRuntimeManager) {
99   reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)->runOnVulkan();
100 }
101 
setEntryPoint(void * vkRuntimeManager,const char * entryPoint)102 VULKAN_WRAPPER_SYMBOL_EXPORT void setEntryPoint(void *vkRuntimeManager,
103                                                 const char *entryPoint) {
104   reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
105       ->setEntryPoint(entryPoint);
106 }
107 
108 VULKAN_WRAPPER_SYMBOL_EXPORT void
setNumWorkGroups(void * vkRuntimeManager,uint32_t x,uint32_t y,uint32_t z)109 setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y, uint32_t z) {
110   reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
111       ->setNumWorkGroups({x, y, z});
112 }
113 
114 VULKAN_WRAPPER_SYMBOL_EXPORT void
setBinaryShader(void * vkRuntimeManager,uint8_t * shader,uint32_t size)115 setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) {
116   reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
117       ->setShaderModule(shader, size);
118 }
119 
120 /// Binds the given memref to the given descriptor set and descriptor
121 /// index.
122 #define DECLARE_BIND_MEMREF(size, type, typeName)                              \
123   VULKAN_WRAPPER_SYMBOL_EXPORT void bindMemRef##size##D##typeName(             \
124       void *vkRuntimeManager, DescriptorSetIndex setIndex,                     \
125       BindingIndex bindIndex, MemRefDescriptor<type, size> *ptr) {             \
126     bindMemRef<type, size>(vkRuntimeManager, setIndex, bindIndex, ptr);        \
127   }
128 
129 DECLARE_BIND_MEMREF(1, float, Float)
130 DECLARE_BIND_MEMREF(2, float, Float)
131 DECLARE_BIND_MEMREF(3, float, Float)
132 DECLARE_BIND_MEMREF(1, int32_t, Int32)
133 DECLARE_BIND_MEMREF(2, int32_t, Int32)
134 DECLARE_BIND_MEMREF(3, int32_t, Int32)
135 DECLARE_BIND_MEMREF(1, int16_t, Int16)
136 DECLARE_BIND_MEMREF(2, int16_t, Int16)
137 DECLARE_BIND_MEMREF(3, int16_t, Int16)
138 DECLARE_BIND_MEMREF(1, int8_t, Int8)
139 DECLARE_BIND_MEMREF(2, int8_t, Int8)
140 DECLARE_BIND_MEMREF(3, int8_t, Int8)
141 DECLARE_BIND_MEMREF(1, int16_t, Half)
142 DECLARE_BIND_MEMREF(2, int16_t, Half)
143 DECLARE_BIND_MEMREF(3, int16_t, Half)
144 
145 /// Fills the given 1D float memref with the given float value.
146 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource1DFloat(MemRefDescriptor<float,1> * ptr,float value)147 _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
148                                  float value) {
149   std::fill_n(ptr->allocated, ptr->sizes[0], value);
150 }
151 
152 /// Fills the given 2D float memref with the given float value.
153 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource2DFloat(MemRefDescriptor<float,2> * ptr,float value)154 _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
155                                  float value) {
156   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
157 }
158 
159 /// Fills the given 3D float memref with the given float value.
160 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource3DFloat(MemRefDescriptor<float,3> * ptr,float value)161 _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
162                                  float value) {
163   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
164               value);
165 }
166 
167 /// Fills the given 1D int memref with the given int value.
168 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t,1> * ptr,int32_t value)169 _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
170                                int32_t value) {
171   std::fill_n(ptr->allocated, ptr->sizes[0], value);
172 }
173 
174 /// Fills the given 2D int memref with the given int value.
175 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t,2> * ptr,int32_t value)176 _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
177                                int32_t value) {
178   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
179 }
180 
181 /// Fills the given 3D int memref with the given int value.
182 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t,3> * ptr,int32_t value)183 _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
184                                int32_t value) {
185   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
186               value);
187 }
188 
189 /// Fills the given 1D int memref with the given int8 value.
190 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t,1> * ptr,int8_t value)191 _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT
192                                 int8_t value) {
193   std::fill_n(ptr->allocated, ptr->sizes[0], value);
194 }
195 
196 /// Fills the given 2D int memref with the given int8 value.
197 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t,2> * ptr,int8_t value)198 _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT
199                                 int8_t value) {
200   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
201 }
202 
203 /// Fills the given 3D int memref with the given int8 value.
204 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t,3> * ptr,int8_t value)205 _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT
206                                 int8_t value) {
207   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
208               value);
209 }
210 }
211