ComputeTests.cpp 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. // Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
  2. // SPDX-FileCopyrightText: 2025 Jorrit Rouwe
  3. // SPDX-License-Identifier: MIT
  4. #include "UnitTestFramework.h"
  5. #include <Jolt/Compute/ComputeSystem.h>
  6. #include <Jolt/Compute/CPU/ComputeSystemCPU.h>
  7. #include <Jolt/Shaders/ShaderCore.h>
  8. #include <Jolt/Shaders/TestComputeBindings.h>
  9. #include <Jolt/Core/IncludeWindows.h>
  10. #include <Jolt/Core/RTTI.h>
  11. JPH_SUPPRESS_WARNINGS_STD_BEGIN
  12. #include <fstream>
  13. #include <filesystem>
  14. #ifdef JPH_PLATFORM_LINUX
  15. #include <unistd.h>
  16. #endif
  17. JPH_SUPPRESS_WARNINGS_STD_END
  18. #if defined(JPH_PLATFORM_MACOS) || defined(JPH_PLATFORM_IOS)
  19. #include <CoreFoundation/CoreFoundation.h>
  20. #endif
  21. JPH_DECLARE_REGISTER_SHADER(TestCompute)
  22. TEST_SUITE("ComputeTests")
  23. {
  /// Shader name that makes the test shader loader return cInvalidShaderCode instead of reading a file
  static const char * const cInvalidShaderName = "InvalidShader";

  /// Deliberately invalid shader source handed out by the test shader loader for cInvalidShaderName
  static const char * const cInvalidShaderCode = "invalid_shader_code";
  26. static void RunTests(ComputeSystem *inComputeSystem)
  27. {
  28. inComputeSystem->mShaderLoader = [](const char *inName, Array<uint8> &outData, String &outError) {
  29. // Special case to test what happens when an invalid file is returned
  30. if (strstr(inName, cInvalidShaderName) != nullptr)
  31. {
  32. outData.assign(cInvalidShaderCode, cInvalidShaderCode + strlen(cInvalidShaderCode));
  33. return true;
  34. }
  35. #if defined(JPH_PLATFORM_MACOS) || defined(JPH_PLATFORM_IOS)
  36. // In macOS the shaders are copied to the bundle
  37. CFBundleRef bundle = CFBundleGetMainBundle();
  38. CFURLRef resources = CFBundleCopyResourcesDirectoryURL(bundle);
  39. CFURLRef absolute = CFURLCopyAbsoluteURL(resources);
  40. CFRelease(resources);
  41. CFStringRef path_string = CFURLCopyFileSystemPath(absolute, kCFURLPOSIXPathStyle);
  42. CFRelease(absolute);
  43. char path[PATH_MAX];
  44. CFStringGetCString(path_string, path, PATH_MAX, kCFStringEncodingUTF8);
  45. CFRelease(path_string);
  46. String base_path = String(path) + "/Jolt/Shaders/";
  47. #else
  48. // On other platforms, start searching up from the application path
  49. #ifdef JPH_PLATFORM_WINDOWS
  50. char application_path[MAX_PATH] = { 0 };
  51. GetModuleFileName(nullptr, application_path, MAX_PATH);
  52. #elif defined(JPH_PLATFORM_LINUX)
  53. char application_path[PATH_MAX] = { 0 };
  54. int count = readlink("/proc/self/exe", application_path, PATH_MAX);
  55. if (count > 0)
  56. application_path[count] = 0;
  57. #else
  58. // Not implemented
  59. const char *application_path = "";
  60. #endif
  61. String base_path;
  62. filesystem::path shader_path(application_path);
  63. while (!shader_path.empty())
  64. {
  65. filesystem::path parent_path = shader_path.parent_path();
  66. if (parent_path == shader_path)
  67. break;
  68. shader_path = parent_path;
  69. filesystem::path full_path = shader_path / "Jolt" / "Shaders" / "";
  70. if (filesystem::exists(full_path))
  71. {
  72. base_path = String(full_path.string());
  73. break;
  74. }
  75. }
  76. #endif
  77. // Open file
  78. std::ifstream input((base_path + inName).c_str(), std::ios::in | std::ios::binary);
  79. if (!input.is_open())
  80. {
  81. outError = String("Could not open shader file: ") + base_path + inName;
  82. #if defined(JPH_PLATFORM_MACOS) || defined(JPH_PLATFORM_IOS)
  83. outError += "\nThis can fail on macOS when dxc or spirv-cross could not be found so the shaders could not be compiled.";
  84. #endif
  85. return false;
  86. }
  87. // Read contents of file
  88. input.seekg(0, ios_base::end);
  89. ifstream::pos_type length = input.tellg();
  90. input.seekg(0, ios_base::beg);
  91. outData.resize(size_t(length));
  92. if (length == 0)
  93. return true;
  94. input.read((char *)&outData[0], length);
  95. return true;
  96. };
  97. // Test failing shader creation
  98. {
  99. ComputeShaderResult shader_result = inComputeSystem->CreateComputeShader("NonExistingShader", 64);
  100. CHECK(shader_result.HasError());
  101. }
  102. constexpr uint32 cNumElements = 1234; // Not a multiple of cTestComputeGroupSize
  103. constexpr uint32 cNumIterations = 10;
  104. constexpr JPH_float3 cFloat3Value = JPH_float3(0, 0, 0);
  105. constexpr JPH_float3 cFloat3Value2 = JPH_float3(0, 13, 0);
  106. constexpr uint32 cUIntValue = 7;
  107. constexpr uint32 cUploadValue = 42;
  108. // Can't change context buffer while commands are queued, so create multiple constant buffers
  109. Ref<ComputeBuffer> context[cNumIterations];
  110. for (uint32 iter = 0; iter < cNumIterations; ++iter)
  111. {
  112. ComputeBufferResult buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(TestComputeContext));
  113. CHECK(!buffer_result.HasError());
  114. context[iter] = buffer_result.Get();
  115. }
  116. CHECK(context != nullptr);
  117. // Create an upload buffer
  118. ComputeBufferResult upload_buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::UploadBuffer, 1, sizeof(uint32));
  119. CHECK(!upload_buffer_result.HasError());
  120. Ref<ComputeBuffer> upload_buffer = upload_buffer_result.Get();
  121. CHECK(upload_buffer != nullptr);
  122. uint32 *upload_data = upload_buffer->Map<uint32>(ComputeBuffer::EMode::Write);
  123. upload_data[0] = cUploadValue;
  124. upload_buffer->Unmap();
  125. // Create a read buffer
  126. UnitTestRandom rnd;
  127. Array<uint32> optional_data(cNumElements);
  128. for (uint32 &d : optional_data)
  129. d = rnd();
  130. ComputeBufferResult optional_buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, cNumElements, sizeof(uint32), optional_data.data());
  131. CHECK(!optional_buffer_result.HasError());
  132. Ref<ComputeBuffer> optional_buffer = optional_buffer_result.Get();
  133. CHECK(optional_buffer != nullptr);
  134. // Create a read-write buffer
  135. ComputeBufferResult buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, cNumElements, sizeof(uint32));
  136. CHECK(!buffer_result.HasError());
  137. Ref<ComputeBuffer> buffer = buffer_result.Get();
  138. CHECK(buffer != nullptr);
  139. // Create a read back buffer
  140. ComputeBufferResult readback_buffer_result = buffer->CreateReadBackBuffer();
  141. CHECK(!readback_buffer_result.HasError());
  142. Ref<ComputeBuffer> readback_buffer = readback_buffer_result.Get();
  143. CHECK(readback_buffer != nullptr);
  144. // Create the shader
  145. ComputeShaderResult shader_result = inComputeSystem->CreateComputeShader("TestCompute", cTestComputeGroupSize);
  146. if (shader_result.HasError())
  147. {
  148. Trace("Shader could not be created: %s", shader_result.GetError().c_str());
  149. return;
  150. }
  151. Ref<ComputeShader> shader = shader_result.Get();
  152. CHECK(shader != nullptr);
  153. // Create the queue
  154. ComputeQueueResult queue_result = inComputeSystem->CreateComputeQueue();
  155. CHECK(!queue_result.HasError());
  156. Ref<ComputeQueue> queue = queue_result.Get();
  157. CHECK(queue != nullptr);
  158. // Schedule work
  159. for (uint32 iter = 0; iter < cNumIterations; ++iter)
  160. {
  161. // Fill in the context
  162. TestComputeContext *value = context[iter]->Map<TestComputeContext>(ComputeBuffer::EMode::Write);
  163. value->cFloat3Value = cFloat3Value;
  164. value->cUIntValue = cUIntValue;
  165. value->cFloat3Value2 = cFloat3Value2;
  166. value->cUIntValue2 = iter;
  167. value->cNumElements = cNumElements;
  168. context[iter]->Unmap();
  169. queue->SetShader(shader);
  170. queue->SetConstantBuffer("gContext", context[iter]);
  171. context[iter] = nullptr; // Release the reference to ensure the queue keeps ownership
  172. queue->SetBuffer("gOptionalData", optional_buffer);
  173. optional_buffer = nullptr; // Release the reference so we test that the queue keeps ownership and that in the 2nd iteration we can set a null buffer
  174. queue->SetBuffer("gUploadData", upload_buffer);
  175. queue->SetRWBuffer("gData", buffer);
  176. queue->Dispatch((cNumElements + cTestComputeGroupSize - 1) / cTestComputeGroupSize);
  177. }
  178. // Run all queued commands
  179. queue->ScheduleReadback(readback_buffer, buffer);
  180. queue->ExecuteAndWait();
  181. // Calculate the expected result
  182. Array<uint32> expected_data(cNumElements);
  183. for (uint32 iter = 0; iter < cNumIterations; ++iter)
  184. {
  185. // Copy of the shader logic
  186. uint cUIntValue2 = iter;
  187. if (cUIntValue2 == 0)
  188. {
  189. // First write, uses optional data and tests that the packing of float3/uint3's works
  190. for (uint32 i = 0; i < cNumElements; ++i)
  191. expected_data[i] = optional_data[i] + int(cFloat3Value2.y) + cUploadValue;
  192. }
  193. else
  194. {
  195. // Read-modify-write gData
  196. for (uint32 i = 0; i < cNumElements; ++i)
  197. expected_data[i] = (expected_data[i] + cUIntValue) * cUIntValue2;
  198. }
  199. }
  200. // Compare computed data with expected data
  201. uint32 *data = readback_buffer->Map<uint32>(ComputeBuffer::EMode::Read);
  202. for (uint32 i = 0; i < cNumElements; ++i)
  203. CHECK(data[i] == expected_data[i]);
  204. readback_buffer->Unmap();
  205. }
  206. #ifdef JPH_USE_DX12
  207. TEST_CASE("TestComputeDX12")
  208. {
  209. ComputeSystemResult compute_system = CreateComputeSystemDX12();
  210. CHECK(!compute_system.HasError());
  211. if (!compute_system.HasError())
  212. {
  213. CHECK(compute_system.Get() != nullptr);
  214. RunTests(compute_system.Get());
  215. // Test failing shader compilation
  216. {
  217. ComputeShaderResult shader_result = compute_system.Get()->CreateComputeShader(cInvalidShaderName, 64);
  218. CHECK(shader_result.HasError());
  219. CHECK(strstr(shader_result.GetError().c_str(), cInvalidShaderCode) != nullptr); // Assume that the error message contains the invalid code
  220. }
  221. }
  222. }
  223. #endif // JPH_USE_DX12
  224. #ifdef JPH_USE_MTL
  225. TEST_CASE("TestComputeMTL")
  226. {
  227. ComputeSystemResult compute_system = CreateComputeSystemMTL();
  228. CHECK(!compute_system.HasError());
  229. if (!compute_system.HasError())
  230. {
  231. CHECK(compute_system.Get() != nullptr);
  232. RunTests(compute_system.Get());
  233. }
  234. }
  235. #endif // JPH_USE_MTL
  236. #ifdef JPH_USE_VK
  237. TEST_CASE("TestComputeVK")
  238. {
  239. ComputeSystemResult compute_system = CreateComputeSystemVK();
  240. CHECK(!compute_system.HasError());
  241. if (!compute_system.HasError())
  242. {
  243. CHECK(compute_system.Get() != nullptr);
  244. RunTests(compute_system.Get());
  245. }
  246. }
  247. #endif // JPH_USE_VK
  248. TEST_CASE("TestComputeCPU")
  249. {
  250. ComputeSystemResult compute_system = CreateComputeSystemCPU();
  251. CHECK(!compute_system.HasError());
  252. if (!compute_system.HasError())
  253. {
  254. CHECK(compute_system.Get() != nullptr);
  255. JPH_REGISTER_SHADER(StaticCast<ComputeSystemCPU>(compute_system.Get()), TestCompute);
  256. RunTests(compute_system.Get());
  257. }
  258. }
  259. }