//===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides a library for running a module on a Vulkan device.
// Implements a Vulkan runtime.
//
//===----------------------------------------------------------------------===//
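//
// Typical host-side usage, as an illustrative sketch only (not part of the
// runtime itself): error handling is abbreviated, and `spirvBlob`,
// `spirvBlobSize`, `entryPointName`, `hostBuffer`, and `numWorkGroups` are
// placeholders the caller must provide.
//
//   VulkanRuntime runtime;
//   runtime.setShaderModule(spirvBlob, spirvBlobSize);
//   runtime.setEntryPoint(entryPointName);
//   runtime.setResourceData(/*desIndex=*/0, /*bindIndex=*/0, hostBuffer);
//   runtime.setNumWorkGroups(numWorkGroups);
//   if (failed(runtime.initRuntime()) || failed(runtime.run()))
//     return;
//   (void)runtime.updateHostMemoryBuffers();
//   (void)runtime.destroy();
//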

#include "VulkanRuntime.h"

#include <chrono>
#include <cstring>
// TODO: It's generally bad to access stdout/stderr in a library.
// Figure out a better way for error reporting.
#include <iomanip>
#include <iostream>

inline void emitVulkanError(const char *api, VkResult error) {
  std::cerr << " failed with error code " << error << " when executing " << api;
}

#define RETURN_ON_VULKAN_ERROR(result, api)                                    \
  if ((result) != VK_SUCCESS) {                                                \
    emitVulkanError(api, (result));                                            \
    return failure();                                                          \
  }

using namespace mlir;

void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
  numWorkGroups = numberWorkGroups;
}

void VulkanRuntime::setResourceStorageClassBindingMap(
    const ResourceStorageClassBindingMap &stClassData) {
  resourceStorageClassData = stClassData;
}

void VulkanRuntime::setResourceData(
    const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
    const VulkanHostMemoryBuffer &hostMemBuffer) {
  resourceData[desIndex][bindIndex] = hostMemBuffer;
  resourceStorageClassData[desIndex][bindIndex] =
      SPIRVStorageClass::StorageBuffer;
}

void VulkanRuntime::setEntryPoint(const char *entryPointName) {
  entryPoint = entryPointName;
}

void VulkanRuntime::setResourceData(const ResourceData &resData) {
  resourceData = resData;
}

void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
  binary = shader;
  binarySize = size;
}

LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
    SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) {
  switch (storageClass) {
  case SPIRVStorageClass::StorageBuffer:
    descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
    break;
  case SPIRVStorageClass::Uniform:
    descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
    break;
  }
  return success();
}

LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
    SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
  switch (storageClass) {
  case SPIRVStorageClass::StorageBuffer:
    bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
    break;
  case SPIRVStorageClass::Uniform:
    bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
    break;
  }
  return success();
}

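// Sums the sizes of all registered resource buffers into `memorySize`. The
// total is later used in createDevice() to make sure the chosen memory heaps
// are large enough.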
LogicalResult VulkanRuntime::countDeviceMemorySize() {
  for (const auto &resourceDataMapPair : resourceData) {
    const auto &resourceDataMap = resourceDataMapPair.second;
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      if (resourceDataBindingPair.second.size) {
        memorySize += resourceDataBindingPair.second.size;
      } else {
        std::cerr << "expected buffer size greater than zero for resource data";
        return failure();
      }
    }
  }
  return success();
}

LogicalResult VulkanRuntime::initRuntime() {
  if (!resourceData.size()) {
    std::cerr << "Vulkan runtime needs at least one resource";
    return failure();
  }
  if (!binarySize || !binary) {
    std::cerr << "binary shader size must be greater than zero";
    return failure();
  }
  if (failed(countDeviceMemorySize())) {
    return failure();
  }
  return success();
}

LogicalResult VulkanRuntime::destroy() {
  // According to Vulkan spec:
  // "To ensure that no work is active on the device, vkDeviceWaitIdle can be
  // used to gate the destruction of the device. Prior to destroying a device,
  // an application is responsible for destroying/freeing any Vulkan objects
  // that were created using that device as the first parameter of the
  // corresponding vkCreate* or vkAllocate* command."
  RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");

  // Free and destroy.
  vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
                       commandBuffers.data());
  vkDestroyQueryPool(device, queryPool, nullptr);
  vkDestroyCommandPool(device, commandPool, nullptr);
  vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(),
                       descriptorSets.data());
  vkDestroyDescriptorPool(device, descriptorPool, nullptr);
  vkDestroyPipeline(device, pipeline, nullptr);
  vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
  for (auto &descriptorSetLayout : descriptorSetLayouts) {
    vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
  }
  vkDestroyShaderModule(device, shaderModule, nullptr);

  // For each descriptor set.
  for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // For each descriptor binding.
    for (auto &memoryBuffer : deviceMemoryBuffers) {
      vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
      vkFreeMemory(device, memoryBuffer.hostMemory, nullptr);
      vkDestroyBuffer(device, memoryBuffer.hostBuffer, nullptr);
      vkDestroyBuffer(device, memoryBuffer.deviceBuffer, nullptr);
    }
  }

  vkDestroyDevice(device, nullptr);
  vkDestroyInstance(instance, nullptr);
  return success();
}

LogicalResult VulkanRuntime::run() {
  // Create logical device, shader module and memory buffers.
  if (failed(createInstance()) || failed(createDevice()) ||
      failed(createMemoryBuffers()) || failed(createShaderModule())) {
    return failure();
  }

  // Descriptor bindings are divided into sets. Each descriptor binding must
  // have a layout binding attached to a descriptor set layout, and each set
  // layout must be bound to a pipeline layout.
  initDescriptorSetLayoutBindingMap();
  if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
      // Each descriptor set must be allocated from a descriptor pool.
      failed(createComputePipeline()) || failed(createDescriptorPool()) ||
      failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
      // Create command buffer.
      failed(createCommandPool()) || failed(createQueryPool()) ||
      failed(createComputeCommandBuffer())) {
    return failure();
  }

  // Get working queue.
  vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);

  if (failed(copyResource(/*deviceToHost=*/false)))
    return failure();

  auto submitStart = std::chrono::high_resolution_clock::now();
  // Submit command buffer into the queue.
  if (failed(submitCommandBuffersToQueue()))
    return failure();
  auto submitEnd = std::chrono::high_resolution_clock::now();

  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
  auto execEnd = std::chrono::high_resolution_clock::now();

  auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      submitEnd - submitStart);
  auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      execEnd - submitEnd);

  if (queryPool != VK_NULL_HANDLE) {
    uint64_t timestamps[2];
    RETURN_ON_VULKAN_ERROR(
        vkGetQueryPoolResults(
            device, queryPool, /*firstQuery=*/0, /*queryCount=*/2,
            /*dataSize=*/sizeof(timestamps),
            /*pData=*/reinterpret_cast<void *>(timestamps),
            /*stride=*/sizeof(uint64_t),
            VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT),
        "vkGetQueryPoolResults");
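    // timestampPeriod is the number of nanoseconds per timestamp tick, so the
    // elapsed tick count is converted to microseconds by dividing by 1000.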
    float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000;
    std::cout << "Compute shader execution time: " << std::setprecision(3)
              << microsec << "us\n";
  }

  std::cout << "Command buffer submit time: " << submitDuration.count()
            << "us\nWait idle time: " << execDuration.count() << "us\n";

  return success();
}

LogicalResult VulkanRuntime::createInstance() {
  VkApplicationInfo applicationInfo = {};
  applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
  applicationInfo.pNext = nullptr;
  applicationInfo.pApplicationName = "MLIR Vulkan runtime";
  applicationInfo.applicationVersion = 0;
  applicationInfo.pEngineName = "mlir";
  applicationInfo.engineVersion = 0;
  applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);

  VkInstanceCreateInfo instanceCreateInfo = {};
  instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
  instanceCreateInfo.pNext = nullptr;
  instanceCreateInfo.flags = 0;
  instanceCreateInfo.pApplicationInfo = &applicationInfo;
  instanceCreateInfo.enabledLayerCount = 0;
  instanceCreateInfo.ppEnabledLayerNames = 0;
  instanceCreateInfo.enabledExtensionCount = 0;
  instanceCreateInfo.ppEnabledExtensionNames = 0;

  RETURN_ON_VULKAN_ERROR(vkCreateInstance(&instanceCreateInfo, 0, &instance),
                         "vkCreateInstance");
  return success();
}

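// Picks a physical device and a compute-capable queue family, creates the
// logical device, and records host-visible and device-local memory type
// indices whose heaps are large enough to hold `memorySize` bytes.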
LogicalResult VulkanRuntime::createDevice() {
  uint32_t physicalDeviceCount = 0;
  RETURN_ON_VULKAN_ERROR(
      vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, 0),
      "vkEnumeratePhysicalDevices");

  std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount);
  RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
                                                    &physicalDeviceCount,
                                                    physicalDevices.data()),
                         "vkEnumeratePhysicalDevices");

  RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
                         "physicalDeviceCount");

  // TODO: find the best device.
  physicalDevice = physicalDevices.front();
  if (failed(getBestComputeQueue()))
    return failure();

  const float queuePriority = 1.0f;
  VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
  deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
  deviceQueueCreateInfo.pNext = nullptr;
  deviceQueueCreateInfo.flags = 0;
  deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
  deviceQueueCreateInfo.queueCount = 1;
  deviceQueueCreateInfo.pQueuePriorities = &queuePriority;

  // Structure specifying parameters of a newly created device.
  VkDeviceCreateInfo deviceCreateInfo = {};
  deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
  deviceCreateInfo.pNext = nullptr;
  deviceCreateInfo.flags = 0;
  deviceCreateInfo.queueCreateInfoCount = 1;
  deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
  deviceCreateInfo.enabledLayerCount = 0;
  deviceCreateInfo.ppEnabledLayerNames = nullptr;
  deviceCreateInfo.enabledExtensionCount = 0;
  deviceCreateInfo.ppEnabledExtensionNames = nullptr;
  deviceCreateInfo.pEnabledFeatures = nullptr;

  RETURN_ON_VULKAN_ERROR(
      vkCreateDevice(physicalDevice, &deviceCreateInfo, 0, &device),
      "vkCreateDevice");

  VkPhysicalDeviceMemoryProperties properties = {};
  vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);

  // Try to find a memory type with the following properties:
  // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT specifies that memory allocated
  // with this type can be mapped for host access using vkMapMemory;
  // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT specifies that the host cache
  // management commands vkFlushMappedMemoryRanges and
  // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
  // device or make device writes visible to the host, respectively.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      hostMemoryTypeIndex = i;
      break;
    }
  }

  // Find a memory type with VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT to be used on
  // the device. This allows better performance access for GPUs with on-device
  // memory.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      deviceMemoryTypeIndex = i;
      break;
    }
  }

  RETURN_ON_VULKAN_ERROR((hostMemoryTypeIndex == VK_MAX_MEMORY_TYPES ||
                          deviceMemoryTypeIndex == VK_MAX_MEMORY_TYPES)
                             ? VK_INCOMPLETE
                             : VK_SUCCESS,
                         "invalid memoryTypeIndex");
  return success();
}

LogicalResult VulkanRuntime::getBestComputeQueue() {
  uint32_t queueFamilyPropertiesCount = 0;
  vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice,
                                           &queueFamilyPropertiesCount, 0);

  std::vector<VkQueueFamilyProperties> familyProperties(
      queueFamilyPropertiesCount);
  vkGetPhysicalDeviceQueueFamilyProperties(
      physicalDevice, &queueFamilyPropertiesCount, familyProperties.data());

  // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
  // compute operations. Try to find a compute-only queue first if possible.
  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
    auto flags = familyProperties[i].queueFlags;
    if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) {
      queueFamilyIndex = i;
      queueFamilyProperties = familyProperties[i];
      return success();
    }
  }

  // Otherwise use a queue that can also support graphics.
  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
    auto flags = familyProperties[i].queueFlags;
    if ((flags & VK_QUEUE_COMPUTE_BIT)) {
      queueFamilyIndex = i;
      queueFamilyProperties = familyProperties[i];
      return success();
    }
  }

  std::cerr << "cannot find valid queue";
  return failure();
}

LogicalResult VulkanRuntime::createMemoryBuffers() {
  // For each descriptor set.
  for (const auto &resourceDataMapPair : resourceData) {
    std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers;
    const auto descriptorSetIndex = resourceDataMapPair.first;
    const auto &resourceDataMap = resourceDataMapPair.second;

    // For each descriptor binding.
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      // Create device memory buffer.
      VulkanDeviceMemoryBuffer memoryBuffer;
      memoryBuffer.bindingIndex = resourceDataBindingPair.first;
      VkDescriptorType descriptorType = {};
      VkBufferUsageFlagBits bufferUsage = {};

      // Check that the descriptor set has a storage class map.
      const auto resourceStorageClassMapIt =
          resourceStorageClassData.find(descriptorSetIndex);
      if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
        std::cerr
            << "cannot find storage class for resource in descriptor set: "
            << descriptorSetIndex;
        return failure();
      }

      // Check that the specific descriptor binding has a storage class.
      const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
      const auto resourceStorageClassIt =
          resourceStorageClassMap.find(resourceDataBindingPair.first);
      if (resourceStorageClassIt == resourceStorageClassMap.end()) {
        std::cerr
            << "cannot find storage class for resource with descriptor index: "
            << resourceDataBindingPair.first;
        return failure();
      }

      const auto resourceStorageClassBinding = resourceStorageClassIt->second;
      if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
                                                 descriptorType)) ||
          failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
                                                  bufferUsage))) {
        std::cerr << "storage class for resource with descriptor binding: "
                  << resourceDataBindingPair.first
                  << " in the descriptor set: " << descriptorSetIndex
                  << " is not supported";
        return failure();
      }

      // Set descriptor type for the specific device memory buffer.
      memoryBuffer.descriptorType = descriptorType;
      const auto bufferSize = resourceDataBindingPair.second.size;
      memoryBuffer.bufferSize = bufferSize;
      // Specify memory allocation info.
      VkMemoryAllocateInfo memoryAllocateInfo = {};
      memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
      memoryAllocateInfo.pNext = nullptr;
      memoryAllocateInfo.allocationSize = bufferSize;
      memoryAllocateInfo.memoryTypeIndex = hostMemoryTypeIndex;

      // Allocate host-visible (staging) and device-local memory.
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0,
                                              &memoryBuffer.hostMemory),
                             "vkAllocateMemory");
      memoryAllocateInfo.memoryTypeIndex = deviceMemoryTypeIndex;
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0,
                                              &memoryBuffer.deviceMemory),
                             "vkAllocateMemory");
      void *payload;
      RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.hostMemory, 0,
                                         bufferSize, 0,
                                         reinterpret_cast<void **>(&payload)),
                             "vkMapMemory");

      // Copy host memory into the mapped area.
      std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize);
      vkUnmapMemory(device, memoryBuffer.hostMemory);

      VkBufferCreateInfo bufferCreateInfo = {};
      bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
      bufferCreateInfo.pNext = nullptr;
      bufferCreateInfo.flags = 0;
      bufferCreateInfo.size = bufferSize;
      bufferCreateInfo.usage = bufferUsage | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                               VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
      bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
      bufferCreateInfo.queueFamilyIndexCount = 1;
      bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, 0,
                                            &memoryBuffer.hostBuffer),
                             "vkCreateBuffer");
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, 0,
                                            &memoryBuffer.deviceBuffer),
                             "vkCreateBuffer");

      // Bind buffer and device memory.
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.hostBuffer,
                                                memoryBuffer.hostMemory, 0),
                             "vkBindBufferMemory");
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device,
                                                memoryBuffer.deviceBuffer,
                                                memoryBuffer.deviceMemory, 0),
                             "vkBindBufferMemory");

      // Update buffer info.
      memoryBuffer.bufferInfo.buffer = memoryBuffer.deviceBuffer;
      memoryBuffer.bufferInfo.offset = 0;
      memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
      deviceMemoryBuffers.push_back(memoryBuffer);
    }

    // Associate device memory buffers with a descriptor set.
    deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
  }
  return success();
}

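// Records and submits a one-shot command buffer that copies every binding's
// contents between its host-visible staging buffer and its device-local
// buffer. `deviceToHost` selects the copy direction.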
LogicalResult VulkanRuntime::copyResource(bool deviceToHost) {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO,
      NULL,
      commandPool,
      VK_COMMAND_BUFFER_LEVEL_PRIMARY,
      1,
  };
  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  VkCommandBufferBeginInfo commandBufferBeginInfo = {
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
      NULL,
      0,
      NULL,
  };
  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    for (const auto &memBuffer : deviceMemoryBuffers) {
      VkBufferCopy copy = {0, 0, memBuffer.bufferSize};
      if (deviceToHost)
        vkCmdCopyBuffer(commandBuffer, memBuffer.deviceBuffer,
                        memBuffer.hostBuffer, 1, &copy);
      else
        vkCmdCopyBuffer(commandBuffer, memBuffer.hostBuffer,
                        memBuffer.deviceBuffer, 1, &copy);
    }
  }

  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");
  VkSubmitInfo submitInfo = {
      VK_STRUCTURE_TYPE_SUBMIT_INFO,
      NULL,
      0,
      NULL,
      NULL,
      1,
      &commandBuffer,
      0,
      NULL,
  };
  submitInfo.pCommandBuffers = &commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE),
                         "vkQueueSubmit");
  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");

  vkFreeCommandBuffers(device, commandPool, 1, &commandBuffer);
  return success();
}

LogicalResult VulkanRuntime::createShaderModule() {
  VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
  shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
  shaderModuleCreateInfo.pNext = nullptr;
  shaderModuleCreateInfo.flags = 0;
  // Set size in bytes.
  shaderModuleCreateInfo.codeSize = binarySize;
  // Set pointer to the binary shader.
  shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
  RETURN_ON_VULKAN_ERROR(
      vkCreateShaderModule(device, &shaderModuleCreateInfo, 0, &shaderModule),
      "vkCreateShaderModule");
  return success();
}

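// Builds one VkDescriptorSetLayoutBinding per device memory buffer, grouped
// by descriptor set index, for later use in createDescriptorSetLayout().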
void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;

    // Create a layout binding for each descriptor.
    for (const auto &memBuffer : deviceMemoryBuffers) {
      VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
      descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
      descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
      descriptorSetLayoutBinding.descriptorCount = 1;
      descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
      descriptorSetLayoutBinding.pImmutableSamplers = 0;
      descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
    }
    descriptorSetLayoutBindingMap[descriptorSetIndex] =
        descriptorSetLayoutBindings;
  }
}

LogicalResult VulkanRuntime::createDescriptorSetLayout() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // Each descriptor in a descriptor set must be of the same type.
    VkDescriptorType descriptorType =
        deviceMemoryBuffers.front().descriptorType;
    const uint32_t descriptorSize = deviceMemoryBuffers.size();
    const auto descriptorSetLayoutBindingIt =
        descriptorSetLayoutBindingMap.find(descriptorSetIndex);

    if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
      std::cerr << "cannot find layout bindings for the set with number: "
                << descriptorSetIndex;
      return failure();
    }

    const auto &descriptorSetLayoutBindings =
        descriptorSetLayoutBindingIt->second;
    // Create descriptor set layout.
    VkDescriptorSetLayout descriptorSetLayout = {};
    VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};

    descriptorSetLayoutCreateInfo.sType =
        VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
    descriptorSetLayoutCreateInfo.pNext = nullptr;
    descriptorSetLayoutCreateInfo.flags = 0;
    // Number of descriptor bindings in the set layout.
    descriptorSetLayoutCreateInfo.bindingCount =
        descriptorSetLayoutBindings.size();
    descriptorSetLayoutCreateInfo.pBindings =
        descriptorSetLayoutBindings.data();
    RETURN_ON_VULKAN_ERROR(
        vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, 0,
                                    &descriptorSetLayout),
        "vkCreateDescriptorSetLayout");

    descriptorSetLayouts.push_back(descriptorSetLayout);
    descriptorSetInfoPool.push_back(
        {descriptorSetIndex, descriptorSize, descriptorType});
  }
  return success();
}

LogicalResult VulkanRuntime::createPipelineLayout() {
  // Associate descriptor sets with a pipeline layout.
  VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
  pipelineLayoutCreateInfo.sType =
      VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
  pipelineLayoutCreateInfo.pNext = nullptr;
  pipelineLayoutCreateInfo.flags = 0;
  pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
  pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
  pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
  pipelineLayoutCreateInfo.pPushConstantRanges = 0;
  RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
                                                &pipelineLayoutCreateInfo, 0,
                                                &pipelineLayout),
                         "vkCreatePipelineLayout");
  return success();
}

LogicalResult VulkanRuntime::createComputePipeline() {
  VkPipelineShaderStageCreateInfo stageInfo = {};
  stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  stageInfo.pNext = nullptr;
  stageInfo.flags = 0;
  stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
  stageInfo.module = shaderModule;
  // Set entry point.
  stageInfo.pName = entryPoint;
  stageInfo.pSpecializationInfo = 0;

  VkComputePipelineCreateInfo computePipelineCreateInfo = {};
  computePipelineCreateInfo.sType =
      VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
  computePipelineCreateInfo.pNext = nullptr;
  computePipelineCreateInfo.flags = 0;
  computePipelineCreateInfo.stage = stageInfo;
  computePipelineCreateInfo.layout = pipelineLayout;
  computePipelineCreateInfo.basePipelineHandle = 0;
  computePipelineCreateInfo.basePipelineIndex = 0;
  RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, 0, 1,
                                                  &computePipelineCreateInfo, 0,
                                                  &pipeline),
                         "vkCreateComputePipelines");
  return success();
}

LogicalResult VulkanRuntime::createDescriptorPool() {
  std::vector<VkDescriptorPoolSize> descriptorPoolSizes;
  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
    // For each descriptor set populate descriptor pool size.
    VkDescriptorPoolSize descriptorPoolSize = {};
    descriptorPoolSize.type = descriptorSetInfo.descriptorType;
    descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
    descriptorPoolSizes.push_back(descriptorPoolSize);
  }

  VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
  descriptorPoolCreateInfo.sType =
      VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
  descriptorPoolCreateInfo.pNext = nullptr;
  descriptorPoolCreateInfo.flags = 0;
  descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
  descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
  descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
  RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
                                                &descriptorPoolCreateInfo, 0,
                                                &descriptorPool),
                         "vkCreateDescriptorPool");
  return success();
}

LogicalResult VulkanRuntime::allocateDescriptorSets() {
  VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
  // Size of descriptor sets and descriptor layout sets is the same.
  descriptorSets.resize(descriptorSetLayouts.size());
  descriptorSetAllocateInfo.sType =
      VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
  descriptorSetAllocateInfo.pNext = nullptr;
  descriptorSetAllocateInfo.descriptorPool = descriptorPool;
  descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
  descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
  RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
                                                  &descriptorSetAllocateInfo,
                                                  descriptorSets.data()),
                         "vkAllocateDescriptorSets");
  return success();
}

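// Writes one VkWriteDescriptorSet per device memory buffer so that every
// (descriptor set, binding) slot points at its device-local buffer.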
LogicalResult VulkanRuntime::setWriteDescriptors() {
  if (descriptorSets.size() != descriptorSetInfoPool.size()) {
    std::cerr << "Each descriptor set must have descriptor set information";
    return failure();
  }
  // For each descriptor set.
  auto descriptorSetIt = descriptorSets.begin();
  // Each descriptor set is associated with descriptor set info.
  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
    // For each device memory buffer in the descriptor set.
    const auto &deviceMemoryBuffers =
        deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
    for (const auto &memoryBuffer : deviceMemoryBuffers) {
      // Structure describing descriptor sets to write to.
      VkWriteDescriptorSet wSet = {};
      wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
      wSet.pNext = nullptr;
      // Descriptor set.
      wSet.dstSet = *descriptorSetIt;
      wSet.dstBinding = memoryBuffer.bindingIndex;
      wSet.dstArrayElement = 0;
      wSet.descriptorCount = 1;
      wSet.descriptorType = memoryBuffer.descriptorType;
      wSet.pImageInfo = nullptr;
      wSet.pBufferInfo = &memoryBuffer.bufferInfo;
      wSet.pTexelBufferView = nullptr;
      vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr);
    }
    // Increment descriptor set iterator.
    ++descriptorSetIt;
  }
  return success();
}

LogicalResult VulkanRuntime::createCommandPool() {
  VkCommandPoolCreateInfo commandPoolCreateInfo = {};
  commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
  commandPoolCreateInfo.pNext = nullptr;
  commandPoolCreateInfo.flags = 0;
  commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
  RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo,
                                             /*pAllocator=*/nullptr,
                                             &commandPool),
                         "vkCreateCommandPool");
  return success();
}

LogicalResult VulkanRuntime::createQueryPool() {
  // Return directly if timestamp query is not supported.
  if (queueFamilyProperties.timestampValidBits == 0)
    return success();

  // Get timestamp period for this physical device.
  VkPhysicalDeviceProperties deviceProperties = {};
  vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties);
  timestampPeriod = deviceProperties.limits.timestampPeriod;

  // Create query pool.
  VkQueryPoolCreateInfo queryPoolCreateInfo = {};
  queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO;
  queryPoolCreateInfo.pNext = nullptr;
  queryPoolCreateInfo.flags = 0;
  queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP;
  queryPoolCreateInfo.queryCount = 2;
  queryPoolCreateInfo.pipelineStatistics = 0;
  RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo,
                                           /*pAllocator=*/nullptr, &queryPool),
                         "vkCreateQueryPool");

  return success();
}

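// Records the compute work into a primary command buffer: bind the pipeline
// and descriptor sets, optionally bracket the dispatch with two timestamp
// queries, and dispatch `numWorkGroups` workgroups.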
LogicalResult VulkanRuntime::createComputeCommandBuffer() {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
  commandBufferAllocateInfo.sType =
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
  commandBufferAllocateInfo.pNext = nullptr;
  commandBufferAllocateInfo.commandPool = commandPool;
  commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
  commandBufferAllocateInfo.commandBufferCount = 1;

  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  VkCommandBufferBeginInfo commandBufferBeginInfo = {};
  commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
  commandBufferBeginInfo.pNext = nullptr;
  commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
  commandBufferBeginInfo.pInheritanceInfo = nullptr;

  // Commands begin.
  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  if (queryPool != VK_NULL_HANDLE)
    vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2);

  vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
  vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                          pipelineLayout, 0, descriptorSets.size(),
                          descriptorSets.data(), 0, 0);
  // Get a timestamp before invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
                        queryPool, 0);
  vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y,
                numWorkGroups.z);
  // Get another timestamp after invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT,
                        queryPool, 1);

  // Commands end.
  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");

  commandBuffers.push_back(commandBuffer);
  return success();
}

LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
  VkSubmitInfo submitInfo = {};
  submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
  submitInfo.pNext = nullptr;
  submitInfo.waitSemaphoreCount = 0;
  submitInfo.pWaitSemaphores = 0;
  submitInfo.pWaitDstStageMask = 0;
  submitInfo.commandBufferCount = commandBuffers.size();
  submitInfo.pCommandBuffers = commandBuffers.data();
  submitInfo.signalSemaphoreCount = 0;
  submitInfo.pSignalSemaphores = nullptr;
  RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, 0),
                         "vkQueueSubmit");
  return success();
}

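// Copies results back from the device: stages the device-local buffers into
// the host-visible buffers, then memcpy's each binding into the caller's
// VulkanHostMemoryBuffer.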
LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
  // First copy back the data to the staging buffer.
  if (failed(copyResource(/*deviceToHost=*/true)))
    return failure();

  // For each descriptor set.
  for (auto &resourceDataMapPair : resourceData) {
    auto &resourceDataMap = resourceDataMapPair.second;
    auto &deviceMemoryBuffers =
        deviceMemoryBufferMap[resourceDataMapPair.first];
    // For each device memory buffer in the set.
    for (auto &deviceMemoryBuffer : deviceMemoryBuffers) {
      if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) {
        void *payload;
        auto &hostMemoryBuffer =
            resourceDataMap[deviceMemoryBuffer.bindingIndex];
        RETURN_ON_VULKAN_ERROR(vkMapMemory(device,
                                           deviceMemoryBuffer.hostMemory, 0,
                                           hostMemoryBuffer.size, 0,
                                           reinterpret_cast<void **>(&payload)),
                               "vkMapMemory");
        std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size);
        vkUnmapMemory(device, deviceMemoryBuffer.hostMemory);
      }
    }
  }
  return success();
}