ComputeTests.cpp 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. // Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
  2. // SPDX-FileCopyrightText: 2025 Jorrit Rouwe
  3. // SPDX-License-Identifier: MIT
  4. #include "UnitTestFramework.h"
  5. #if defined(JPH_USE_DX12) || defined(JPH_USE_MTL) || defined(JPH_USE_VK)
  6. #include <Jolt/Compute/ComputeSystem.h>
  7. #include <Jolt/Shaders/TestCompute.h>
  8. #include <Jolt/Core/IncludeWindows.h>
  9. JPH_SUPPRESS_WARNINGS_STD_BEGIN
  10. #include <fstream>
  11. #include <filesystem>
  12. #ifdef JPH_PLATFORM_LINUX
  13. #include <unistd.h>
  14. #endif
  15. JPH_SUPPRESS_WARNINGS_STD_END
  16. #if defined(JPH_PLATFORM_MACOS) || defined(JPH_PLATFORM_IOS)
  17. #include <CoreFoundation/CoreFoundation.h>
  18. #endif
  19. TEST_SUITE("ComputeTests")
  20. {
  21. static const char *cInvalidShaderName = "InvalidShader";
  22. static const char *cInvalidShaderCode = "invalid_shader_code";
  23. static void RunTests(ComputeSystem *inComputeSystem)
  24. {
  25. inComputeSystem->mShaderLoader = [](const char *inName, Array<uint8> &outData, String &outError) {
  26. // Special case to test what happens when an invalid file is returned
  27. if (strstr(inName, cInvalidShaderName) != nullptr)
  28. {
  29. outData.assign(cInvalidShaderCode, cInvalidShaderCode + strlen(cInvalidShaderCode));
  30. return true;
  31. }
  32. #if defined(JPH_PLATFORM_MACOS) || defined(JPH_PLATFORM_IOS)
  33. // In macOS the shaders are copied to the bundle
  34. CFBundleRef bundle = CFBundleGetMainBundle();
  35. CFURLRef resources = CFBundleCopyResourcesDirectoryURL(bundle);
  36. CFURLRef absolute = CFURLCopyAbsoluteURL(resources);
  37. CFRelease(resources);
  38. CFStringRef path_string = CFURLCopyFileSystemPath(absolute, kCFURLPOSIXPathStyle);
  39. CFRelease(absolute);
  40. char path[PATH_MAX];
  41. CFStringGetCString(path_string, path, PATH_MAX, kCFStringEncodingUTF8);
  42. CFRelease(path_string);
  43. String base_path = String(path) + "/Jolt/Shaders/";
  44. #else
  45. // On other platforms, start searching up from the application path
  46. #ifdef JPH_PLATFORM_WINDOWS
  47. char application_path[MAX_PATH] = { 0 };
  48. GetModuleFileName(nullptr, application_path, MAX_PATH);
  49. #elif defined(JPH_PLATFORM_LINUX)
  50. char application_path[PATH_MAX] = { 0 };
  51. int count = readlink("/proc/self/exe", application_path, PATH_MAX);
  52. if (count > 0)
  53. application_path[count] = 0;
  54. #else
  55. #error Unsupported platform
  56. #endif
  57. String base_path;
  58. filesystem::path shader_path(application_path);
  59. while (!shader_path.empty())
  60. {
  61. filesystem::path parent_path = shader_path.parent_path();
  62. if (parent_path == shader_path)
  63. break;
  64. shader_path = parent_path;
  65. filesystem::path full_path = shader_path / "Jolt" / "Shaders" / "";
  66. if (filesystem::exists(full_path))
  67. {
  68. base_path = String(full_path.string());
  69. break;
  70. }
  71. }
  72. #endif
  73. // Open file
  74. std::ifstream input((base_path + inName).c_str(), std::ios::in | std::ios::binary);
  75. if (!input.is_open())
  76. {
  77. outError = String("Could not open shader file: ") + base_path + inName;
  78. #if defined(JPH_PLATFORM_MACOS) || defined(JPH_PLATFORM_IOS)
  79. outError += "\nThis can fail on macOS when dxc or spirv-cross could not be found so the shaders could not be compiled.";
  80. #endif
  81. return false;
  82. }
  83. // Read contents of file
  84. input.seekg(0, ios_base::end);
  85. ifstream::pos_type length = input.tellg();
  86. input.seekg(0, ios_base::beg);
  87. outData.resize(size_t(length));
  88. if (length == 0)
  89. return true;
  90. input.read((char *)&outData[0], length);
  91. return true;
  92. };
  93. // Test failing shader creation
  94. {
  95. ComputeShaderResult shader_result = inComputeSystem->CreateComputeShader("NonExistingShader", 64);
  96. CHECK(shader_result.HasError());
  97. }
  98. constexpr uint32 cNumElements = 1234; // Not a multiple of cTestComputeGroupSize
  99. constexpr uint32 cNumIterations = 10;
  100. constexpr JPH_float3 cFloat3Value = JPH_float3(0, 0, 0);
  101. constexpr JPH_float3 cFloat3Value2 = JPH_float3(0, 13, 0);
  102. constexpr uint32 cUIntValue = 7;
  103. constexpr uint32 cUploadValue = 42;
  104. // Can't change context buffer while commands are queued, so create multiple constant buffers
  105. Ref<ComputeBuffer> context[cNumIterations];
  106. for (uint32 iter = 0; iter < cNumIterations; ++iter)
  107. {
  108. ComputeBufferResult buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(TestComputeContext));
  109. CHECK(!buffer_result.HasError());
  110. context[iter] = buffer_result.Get();
  111. }
  112. CHECK(context != nullptr);
  113. // Create an upload buffer
  114. ComputeBufferResult upload_buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::UploadBuffer, 1, sizeof(uint32));
  115. CHECK(!upload_buffer_result.HasError());
  116. Ref<ComputeBuffer> upload_buffer = upload_buffer_result.Get();
  117. CHECK(upload_buffer != nullptr);
  118. uint32 *upload_data = upload_buffer->Map<uint32>(ComputeBuffer::EMode::Write);
  119. upload_data[0] = cUploadValue;
  120. upload_buffer->Unmap();
  121. // Create a read buffer
  122. UnitTestRandom rnd;
  123. Array<uint32> optional_data(cNumElements);
  124. for (uint32 &d : optional_data)
  125. d = rnd();
  126. ComputeBufferResult optional_buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, cNumElements, sizeof(uint32), optional_data.data());
  127. CHECK(!optional_buffer_result.HasError());
  128. Ref<ComputeBuffer> optional_buffer = optional_buffer_result.Get();
  129. CHECK(optional_buffer != nullptr);
  130. // Create a read-write buffer
  131. ComputeBufferResult buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, cNumElements, sizeof(uint32));
  132. CHECK(!buffer_result.HasError());
  133. Ref<ComputeBuffer> buffer = buffer_result.Get();
  134. CHECK(buffer != nullptr);
  135. // Create a read back buffer
  136. ComputeBufferResult readback_buffer_result = buffer->CreateReadBackBuffer();
  137. CHECK(!readback_buffer_result.HasError());
  138. Ref<ComputeBuffer> readback_buffer = readback_buffer_result.Get();
  139. CHECK(readback_buffer != nullptr);
  140. // Create the shader
  141. ComputeShaderResult shader_result = inComputeSystem->CreateComputeShader("TestCompute", cTestComputeGroupSize);
  142. if (shader_result.HasError())
  143. {
  144. Trace("Shader could not be created: %s", shader_result.GetError().c_str());
  145. return;
  146. }
  147. Ref<ComputeShader> shader = shader_result.Get();
  148. CHECK(shader != nullptr);
  149. // Create the queue
  150. ComputeQueueResult queue_result = inComputeSystem->CreateComputeQueue();
  151. CHECK(!queue_result.HasError());
  152. Ref<ComputeQueue> queue = queue_result.Get();
  153. CHECK(queue != nullptr);
  154. // Schedule work
  155. for (uint32 iter = 0; iter < cNumIterations; ++iter)
  156. {
  157. // Fill in the context
  158. TestComputeContext *value = context[iter]->Map<TestComputeContext>(ComputeBuffer::EMode::Write);
  159. value->cFloat3Value = cFloat3Value;
  160. value->cUIntValue = cUIntValue;
  161. value->cFloat3Value2 = cFloat3Value2;
  162. value->cUIntValue2 = iter;
  163. value->cNumElements = cNumElements;
  164. context[iter]->Unmap();
  165. queue->SetShader(shader);
  166. queue->SetConstantBuffer("gContext", context[iter]);
  167. context[iter] = nullptr; // Release the reference to ensure the queue keeps ownership
  168. queue->SetBuffer("gOptionalData", optional_buffer);
  169. optional_buffer = nullptr; // Release the reference so we test that the queue keeps ownership and that in the 2nd iteration we can set a null buffer
  170. queue->SetBuffer("gUploadData", upload_buffer);
  171. queue->SetRWBuffer("gData", buffer);
  172. queue->Dispatch((cNumElements + cTestComputeGroupSize - 1) / cTestComputeGroupSize);
  173. }
  174. // Run all queued commands
  175. queue->ScheduleReadback(readback_buffer, buffer);
  176. queue->ExecuteAndWait();
  177. // Calculate the expected result
  178. Array<uint32> expected_data(cNumElements);
  179. for (uint32 iter = 0; iter < cNumIterations; ++iter)
  180. {
  181. // Copy of the shader logic
  182. uint cUIntValue2 = iter;
  183. if (cUIntValue2 == 0)
  184. {
  185. // First write, uses optional data and tests that the packing of float3/uint3's works
  186. for (uint32 i = 0; i < cNumElements; ++i)
  187. expected_data[i] = optional_data[i] + int(cFloat3Value2.y) + cUploadValue;
  188. }
  189. else
  190. {
  191. // Read-modify-write gData
  192. for (uint32 i = 0; i < cNumElements; ++i)
  193. expected_data[i] = (expected_data[i] + cUIntValue) * cUIntValue2;
  194. }
  195. }
  196. // Compare computed data with expected data
  197. uint32 *data = readback_buffer->Map<uint32>(ComputeBuffer::EMode::Read);
  198. for (uint32 i = 0; i < cNumElements; ++i)
  199. CHECK(data[i] == expected_data[i]);
  200. readback_buffer->Unmap();
  201. }
  202. #ifdef JPH_USE_DX12
  203. TEST_CASE("TestComputeDX12")
  204. {
  205. ComputeSystemResult compute_system = CreateComputeSystemDX12();
  206. CHECK(!compute_system.HasError());
  207. if (!compute_system.HasError())
  208. {
  209. CHECK(compute_system.Get() != nullptr);
  210. RunTests(compute_system.Get());
  211. // Test failing shader compilation
  212. {
  213. ComputeShaderResult shader_result = compute_system.Get()->CreateComputeShader(cInvalidShaderName, 64);
  214. CHECK(shader_result.HasError());
  215. CHECK(strstr(shader_result.GetError().c_str(), cInvalidShaderCode) != nullptr); // Assume that the error message contains the invalid code
  216. }
  217. }
  218. }
  219. #endif // JPH_USE_DX12
  220. #ifdef JPH_USE_MTL
  221. TEST_CASE("TestComputeMTL")
  222. {
  223. ComputeSystemResult compute_system = CreateComputeSystemMTL();
  224. CHECK(!compute_system.HasError());
  225. if (!compute_system.HasError())
  226. {
  227. CHECK(compute_system.Get() != nullptr);
  228. RunTests(compute_system.Get());
  229. }
  230. }
  231. #endif // JPH_USE_MTL
  232. #ifdef JPH_USE_VK
  233. TEST_CASE("TestComputeVK")
  234. {
  235. ComputeSystemResult compute_system = CreateComputeSystemVK();
  236. CHECK(!compute_system.HasError());
  237. if (!compute_system.HasError())
  238. {
  239. CHECK(compute_system.Get() != nullptr);
  240. RunTests(compute_system.Get());
  241. }
  242. }
  243. #endif // JPH_USE_VK
  244. }
  245. #endif // defined(JPH_USE_DX12) || defined(JPH_USE_MTL) || defined(JPH_USE_VK)