2
0

ComputeSystemVKImpl.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. // Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
  2. // SPDX-FileCopyrightText: 2025 Jorrit Rouwe
  3. // SPDX-License-Identifier: MIT
  4. #include <Jolt/Jolt.h>
  5. #ifdef JPH_USE_VK
  6. #include <Jolt/Compute/VK/ComputeSystemVKImpl.h>
  7. #include <Jolt/Core/QuickSort.h>
  8. JPH_NAMESPACE_BEGIN
  9. JPH_IMPLEMENT_RTTI_VIRTUAL(ComputeSystemVKImpl)
  10. {
  11. JPH_ADD_BASE_CLASS(ComputeSystemVKImpl, ComputeSystemVKWithAllocator)
  12. }
  13. #ifdef JPH_DEBUG
  14. static VKAPI_ATTR VkBool32 VKAPI_CALL sVulkanDebugCallback(VkDebugUtilsMessageSeverityFlagBitsEXT inSeverity, [[maybe_unused]] VkDebugUtilsMessageTypeFlagsEXT inType, const VkDebugUtilsMessengerCallbackDataEXT *inCallbackData, [[maybe_unused]] void *inUserData)
  15. {
  16. if (inSeverity & (VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT))
  17. Trace("VK: %s", inCallbackData->pMessage);
  18. JPH_ASSERT((inSeverity & VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT) == 0);
  19. return VK_FALSE;
  20. }
  21. #endif // JPH_DEBUG
  22. ComputeSystemVKImpl::~ComputeSystemVKImpl()
  23. {
  24. ComputeSystemVK::Shutdown();
  25. if (mDevice != VK_NULL_HANDLE)
  26. vkDestroyDevice(mDevice, nullptr);
  27. #ifdef JPH_DEBUG
  28. PFN_vkDestroyDebugUtilsMessengerEXT vkDestroyDebugUtilsMessengerEXT = (PFN_vkDestroyDebugUtilsMessengerEXT)(void *)vkGetInstanceProcAddr(mInstance, "vkDestroyDebugUtilsMessengerEXT");
  29. if (mInstance != VK_NULL_HANDLE && mDebugMessenger != VK_NULL_HANDLE && vkDestroyDebugUtilsMessengerEXT != nullptr)
  30. vkDestroyDebugUtilsMessengerEXT(mInstance, mDebugMessenger, nullptr);
  31. #endif
  32. if (mInstance != VK_NULL_HANDLE)
  33. vkDestroyInstance(mInstance, nullptr);
  34. }
  35. bool ComputeSystemVKImpl::Initialize(ComputeSystemResult &outResult)
  36. {
  37. // Required instance extensions
  38. Array<const char *> required_instance_extensions;
  39. required_instance_extensions.push_back(VK_KHR_SURFACE_EXTENSION_NAME);
  40. required_instance_extensions.push_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
  41. #ifdef JPH_PLATFORM_MACOS
  42. required_instance_extensions.push_back("VK_KHR_portability_enumeration");
  43. required_instance_extensions.push_back("VK_KHR_get_physical_device_properties2");
  44. #endif
  45. GetInstanceExtensions(required_instance_extensions);
  46. // Required device extensions
  47. Array<const char *> required_device_extensions;
  48. required_device_extensions.push_back(VK_EXT_SCALAR_BLOCK_LAYOUT_EXTENSION_NAME);
  49. #ifdef JPH_PLATFORM_MACOS
  50. required_device_extensions.push_back("VK_KHR_portability_subset"); // VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME
  51. #endif
  52. GetDeviceExtensions(required_device_extensions);
  53. // Query supported instance extensions
  54. uint32 instance_extension_count = 0;
  55. if (VKFailed(vkEnumerateInstanceExtensionProperties(nullptr, &instance_extension_count, nullptr), outResult))
  56. return false;
  57. Array<VkExtensionProperties> instance_extensions;
  58. instance_extensions.resize(instance_extension_count);
  59. if (VKFailed(vkEnumerateInstanceExtensionProperties(nullptr, &instance_extension_count, instance_extensions.data()), outResult))
  60. return false;
  61. // Query supported validation layers
  62. uint32 validation_layer_count;
  63. vkEnumerateInstanceLayerProperties(&validation_layer_count, nullptr);
  64. Array<VkLayerProperties> validation_layers(validation_layer_count);
  65. vkEnumerateInstanceLayerProperties(&validation_layer_count, validation_layers.data());
  66. VkApplicationInfo app_info = {};
  67. app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
  68. app_info.apiVersion = VK_API_VERSION_1_1;
  69. // Create Vulkan instance
  70. VkInstanceCreateInfo instance_create_info = {};
  71. instance_create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
  72. #ifdef JPH_PLATFORM_MACOS
  73. instance_create_info.flags = VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR;
  74. #endif
  75. instance_create_info.pApplicationInfo = &app_info;
  76. #ifdef JPH_DEBUG
  77. // Enable validation layer if supported
  78. const char *desired_validation_layers[] = { "VK_LAYER_KHRONOS_validation" };
  79. for (const VkLayerProperties &p : validation_layers)
  80. if (strcmp(desired_validation_layers[0], p.layerName) == 0)
  81. {
  82. instance_create_info.enabledLayerCount = 1;
  83. instance_create_info.ppEnabledLayerNames = desired_validation_layers;
  84. break;
  85. }
  86. // Setup debug messenger callback if the extension is supported
  87. VkDebugUtilsMessengerCreateInfoEXT messenger_create_info = {};
  88. for (const VkExtensionProperties &ext : instance_extensions)
  89. if (strcmp(VK_EXT_DEBUG_UTILS_EXTENSION_NAME, ext.extensionName) == 0)
  90. {
  91. messenger_create_info.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT;
  92. messenger_create_info.messageSeverity = VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT;
  93. messenger_create_info.messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT;
  94. messenger_create_info.pfnUserCallback = sVulkanDebugCallback;
  95. instance_create_info.pNext = &messenger_create_info;
  96. required_instance_extensions.push_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
  97. break;
  98. }
  99. #endif
  100. instance_create_info.enabledExtensionCount = (uint32)required_instance_extensions.size();
  101. instance_create_info.ppEnabledExtensionNames = required_instance_extensions.data();
  102. if (VKFailed(vkCreateInstance(&instance_create_info, nullptr, &mInstance), outResult))
  103. return false;
  104. #ifdef JPH_DEBUG
  105. // Finalize debug messenger callback
  106. PFN_vkCreateDebugUtilsMessengerEXT vkCreateDebugUtilsMessengerEXT = (PFN_vkCreateDebugUtilsMessengerEXT)(std::uintptr_t)vkGetInstanceProcAddr(mInstance, "vkCreateDebugUtilsMessengerEXT");
  107. if (vkCreateDebugUtilsMessengerEXT != nullptr)
  108. if (VKFailed(vkCreateDebugUtilsMessengerEXT(mInstance, &messenger_create_info, nullptr, &mDebugMessenger), outResult))
  109. return false;
  110. #endif
  111. // Notify that instance has been created
  112. OnInstanceCreated();
  113. // Select device
  114. uint32 device_count = 0;
  115. if (VKFailed(vkEnumeratePhysicalDevices(mInstance, &device_count, nullptr), outResult))
  116. return false;
  117. Array<VkPhysicalDevice> devices;
  118. devices.resize(device_count);
  119. if (VKFailed(vkEnumeratePhysicalDevices(mInstance, &device_count, devices.data()), outResult))
  120. return false;
  121. struct Device
  122. {
  123. VkPhysicalDevice mPhysicalDevice;
  124. String mName;
  125. VkSurfaceFormatKHR mFormat;
  126. uint32 mGraphicsQueueIndex;
  127. uint32 mPresentQueueIndex;
  128. uint32 mComputeQueueIndex;
  129. int mScore;
  130. };
  131. Array<Device> available_devices;
  132. for (VkPhysicalDevice device : devices)
  133. {
  134. // Get device properties
  135. VkPhysicalDeviceProperties properties;
  136. vkGetPhysicalDeviceProperties(device, &properties);
  137. // Test if it is an appropriate type
  138. int score = 0;
  139. switch (properties.deviceType)
  140. {
  141. case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU:
  142. score = 30;
  143. break;
  144. case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU:
  145. score = 20;
  146. break;
  147. case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU:
  148. score = 10;
  149. break;
  150. case VK_PHYSICAL_DEVICE_TYPE_CPU:
  151. score = 5;
  152. break;
  153. case VK_PHYSICAL_DEVICE_TYPE_OTHER:
  154. case VK_PHYSICAL_DEVICE_TYPE_MAX_ENUM:
  155. continue;
  156. }
  157. // Check if the device supports all our required extensions
  158. uint32 device_extension_count;
  159. vkEnumerateDeviceExtensionProperties(device, nullptr, &device_extension_count, nullptr);
  160. Array<VkExtensionProperties> available_extensions;
  161. available_extensions.resize(device_extension_count);
  162. vkEnumerateDeviceExtensionProperties(device, nullptr, &device_extension_count, available_extensions.data());
  163. int found_extensions = 0;
  164. for (const char *required_device_extension : required_device_extensions)
  165. for (const VkExtensionProperties &ext : available_extensions)
  166. if (strcmp(required_device_extension, ext.extensionName) == 0)
  167. {
  168. found_extensions++;
  169. break;
  170. }
  171. if (found_extensions != int(required_device_extensions.size()))
  172. continue;
  173. // Find the right queues
  174. uint32 queue_family_count = 0;
  175. vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, nullptr);
  176. Array<VkQueueFamilyProperties> queue_families;
  177. queue_families.resize(queue_family_count);
  178. vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, queue_families.data());
  179. uint32 graphics_queue = ~uint32(0);
  180. uint32 present_queue = ~uint32(0);
  181. uint32 compute_queue = ~uint32(0);
  182. for (uint32 i = 0; i < uint32(queue_families.size()); ++i)
  183. {
  184. if (queue_families[i].queueFlags & VK_QUEUE_GRAPHICS_BIT)
  185. {
  186. graphics_queue = i;
  187. if (queue_families[i].queueFlags & VK_QUEUE_COMPUTE_BIT)
  188. compute_queue = i;
  189. }
  190. if (HasPresentSupport(device, i))
  191. present_queue = i;
  192. if (graphics_queue != ~uint32(0) && present_queue != ~uint32(0) && compute_queue != ~uint32(0))
  193. break;
  194. }
  195. if (graphics_queue == ~uint32(0) || present_queue == ~uint32(0) || compute_queue == ~uint32(0))
  196. continue;
  197. // Select surface format
  198. VkSurfaceFormatKHR selected_format = SelectFormat(device);
  199. if (selected_format.format == VK_FORMAT_UNDEFINED)
  200. continue;
  201. // Add the device
  202. available_devices.push_back({ device, properties.deviceName, selected_format, graphics_queue, present_queue, compute_queue, score });
  203. }
  204. if (available_devices.empty())
  205. {
  206. outResult.SetError("No suitable Vulkan device found");
  207. return false;
  208. }
  209. // Sort the devices by score
  210. QuickSort(available_devices.begin(), available_devices.end(), [](const Device &inLHS, const Device &inRHS) {
  211. return inLHS.mScore > inRHS.mScore;
  212. });
  213. const Device &selected_device = available_devices[0];
  214. // Create device
  215. float queue_priority = 1.0f;
  216. VkDeviceQueueCreateInfo queue_create_info[3] = {};
  217. for (VkDeviceQueueCreateInfo &q : queue_create_info)
  218. {
  219. q.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
  220. q.queueCount = 1;
  221. q.pQueuePriorities = &queue_priority;
  222. }
  223. uint32 num_queues = 0;
  224. queue_create_info[num_queues++].queueFamilyIndex = selected_device.mGraphicsQueueIndex;
  225. bool found = false;
  226. for (uint32 i = 0; i < num_queues; ++i)
  227. if (queue_create_info[i].queueFamilyIndex == selected_device.mPresentQueueIndex)
  228. {
  229. found = true;
  230. break;
  231. }
  232. if (!found)
  233. queue_create_info[num_queues++].queueFamilyIndex = selected_device.mPresentQueueIndex;
  234. found = false;
  235. for (uint32 i = 0; i < num_queues; ++i)
  236. if (queue_create_info[i].queueFamilyIndex == selected_device.mComputeQueueIndex)
  237. {
  238. found = true;
  239. break;
  240. }
  241. if (!found)
  242. queue_create_info[num_queues++].queueFamilyIndex = selected_device.mComputeQueueIndex;
  243. VkPhysicalDeviceScalarBlockLayoutFeatures enable_scalar_block = {};
  244. enable_scalar_block.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SCALAR_BLOCK_LAYOUT_FEATURES;
  245. enable_scalar_block.scalarBlockLayout = VK_TRUE;
  246. VkPhysicalDeviceFeatures2 enabled_features2 = {};
  247. enabled_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
  248. GetEnabledFeatures(enabled_features2);
  249. enable_scalar_block.pNext = enabled_features2.pNext;
  250. enabled_features2.pNext = &enable_scalar_block;
  251. VkDeviceCreateInfo device_create_info = {};
  252. device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
  253. device_create_info.queueCreateInfoCount = num_queues;
  254. device_create_info.pQueueCreateInfos = queue_create_info;
  255. device_create_info.enabledLayerCount = instance_create_info.enabledLayerCount;
  256. device_create_info.ppEnabledLayerNames = instance_create_info.ppEnabledLayerNames;
  257. device_create_info.enabledExtensionCount = uint32(required_device_extensions.size());
  258. device_create_info.ppEnabledExtensionNames = required_device_extensions.data();
  259. device_create_info.pNext = &enabled_features2;
  260. device_create_info.pEnabledFeatures = nullptr;
  261. VkDevice device = VK_NULL_HANDLE;
  262. if (VKFailed(vkCreateDevice(selected_device.mPhysicalDevice, &device_create_info, nullptr, &device), outResult))
  263. return false;
  264. // Get the queues
  265. mGraphicsQueueIndex = selected_device.mGraphicsQueueIndex;
  266. mPresentQueueIndex = selected_device.mPresentQueueIndex;
  267. vkGetDeviceQueue(device, mGraphicsQueueIndex, 0, &mGraphicsQueue);
  268. vkGetDeviceQueue(device, mPresentQueueIndex, 0, &mPresentQueue);
  269. // Store selected format
  270. mSelectedFormat = selected_device.mFormat;
  271. // Initialize the compute system
  272. return ComputeSystemVK::Initialize(selected_device.mPhysicalDevice, device, selected_device.mComputeQueueIndex, outResult);
  273. }
  274. ComputeSystemResult CreateComputeSystemVK()
  275. {
  276. ComputeSystemResult result;
  277. Ref<ComputeSystemVKImpl> compute = new ComputeSystemVKImpl;
  278. if (!compute->Initialize(result))
  279. return result;
  280. result.Set(compute.GetPtr());
  281. return result;
  282. }
  283. JPH_NAMESPACE_END
  284. #endif // JPH_USE_VK