ComputeTests.cpp 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. // Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
  2. // SPDX-FileCopyrightText: 2025 Jorrit Rouwe
  3. // SPDX-License-Identifier: MIT
  4. #include "UnitTestFramework.h"
  5. #if defined(JPH_USE_DX12) || defined(JPH_USE_MTL) || defined(JPH_USE_VK)
  6. #include <Jolt/Compute/ComputeSystem.h>
  7. #include <Jolt/Shaders/TestCompute.h>
  8. #include <Jolt/Core/IncludeWindows.h>
  9. JPH_SUPPRESS_WARNINGS_STD_BEGIN
  10. #include <fstream>
  11. #include <filesystem>
  12. #ifdef JPH_PLATFORM_LINUX
  13. #include <unistd.h>
  14. #endif
  15. JPH_SUPPRESS_WARNINGS_STD_END
  16. #if defined(JPH_PLATFORM_MACOS) || defined(JPH_PLATFORM_IOS)
  17. #include <CoreFoundation/CoreFoundation.h>
  18. #endif
  19. TEST_SUITE("ComputeTests")
  20. {
  21. static void RunTests(ComputeSystem *inComputeSystem)
  22. {
  23. inComputeSystem->mShaderLoader = [](const char *inName, Array<uint8> &outData) {
  24. #if defined(JPH_PLATFORM_MACOS) || defined(JPH_PLATFORM_IOS)
  25. // In macOS the shaders are copied to the bundle
  26. CFBundleRef bundle = CFBundleGetMainBundle();
  27. CFURLRef resources = CFBundleCopyResourcesDirectoryURL(bundle);
  28. CFURLRef absolute = CFURLCopyAbsoluteURL(resources);
  29. CFRelease(resources);
  30. CFStringRef path_string = CFURLCopyFileSystemPath(absolute, kCFURLPOSIXPathStyle);
  31. CFRelease(absolute);
  32. char path[PATH_MAX];
  33. CFStringGetCString(path_string, path, PATH_MAX, kCFStringEncodingUTF8);
  34. CFRelease(path_string);
  35. String base_path = String(path) + "/Jolt/Shaders/";
  36. #else
  37. // On other platforms, start searching up from the application path
  38. #ifdef JPH_PLATFORM_WINDOWS
  39. char application_path[MAX_PATH] = { 0 };
  40. GetModuleFileName(nullptr, application_path, MAX_PATH);
  41. #elif defined(JPH_PLATFORM_LINUX)
  42. char application_path[PATH_MAX] = { 0 };
  43. int count = readlink("/proc/self/exe", application_path, PATH_MAX);
  44. if (count > 0)
  45. application_path[count] = 0;
  46. #else
  47. #error Unsupported platform
  48. #endif
  49. String base_path;
  50. filesystem::path shader_path(application_path);
  51. while (!shader_path.empty())
  52. {
  53. filesystem::path parent_path = shader_path.parent_path();
  54. if (parent_path == shader_path)
  55. break;
  56. shader_path = parent_path;
  57. filesystem::path full_path = shader_path / "Jolt" / "Shaders" / "";
  58. if (filesystem::exists(full_path))
  59. {
  60. base_path = String(full_path.string());
  61. break;
  62. }
  63. }
  64. #endif
  65. // Open file
  66. std::ifstream input((base_path + inName).c_str(), std::ios::in | std::ios::binary);
  67. if (!input.is_open())
  68. return false;
  69. // Read contents of file
  70. input.seekg(0, ios_base::end);
  71. ifstream::pos_type length = input.tellg();
  72. input.seekg(0, ios_base::beg);
  73. outData.resize(size_t(length));
  74. input.read((char *)&outData[0], length);
  75. return true;
  76. };
  77. constexpr uint32 cNumElements = 1234; // Not a multiple of cTestComputeGroupSize
  78. constexpr uint32 cNumIterations = 10;
  79. constexpr JPH_float3 cFloat3Value = JPH_float3(0, 0, 0);
  80. constexpr JPH_float3 cFloat3Value2 = JPH_float3(0, 13, 0);
  81. constexpr uint32 cUIntValue = 7;
  82. constexpr uint32 cUploadValue = 42;
  83. // Can't change context buffer while commands are queued, so create multiple constant buffers
  84. Ref<ComputeBuffer> context[cNumIterations];
  85. for (uint32 iter = 0; iter < cNumIterations; ++iter)
  86. context[iter] = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(TestComputeContext));
  87. CHECK(context != nullptr);
  88. // Create an upload buffer
  89. Ref<ComputeBuffer> upload_buffer = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::UploadBuffer, 1, sizeof(uint32));
  90. CHECK(upload_buffer != nullptr);
  91. uint32 *upload_data = upload_buffer->Map<uint32>(ComputeBuffer::EMode::Write);
  92. upload_data[0] = cUploadValue;
  93. upload_buffer->Unmap();
  94. // Create a read buffer
  95. UnitTestRandom rnd;
  96. Array<uint32> optional_data(cNumElements);
  97. for (uint32 &d : optional_data)
  98. d = rnd();
  99. Ref<ComputeBuffer> optional_buffer = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, cNumElements, sizeof(uint32), optional_data.data());
  100. CHECK(optional_buffer != nullptr);
  101. // Create a read-write buffer
  102. Ref<ComputeBuffer> buffer = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, cNumElements, sizeof(uint32));
  103. CHECK(buffer != nullptr);
  104. // Create a read back buffer
  105. Ref<ComputeBuffer> readback_buffer = buffer->CreateReadBackBuffer();
  106. CHECK(readback_buffer != nullptr);
  107. // Create the shader
  108. Ref<ComputeShader> shader = inComputeSystem->CreateComputeShader("TestCompute", cTestComputeGroupSize);
  109. CHECK(shader != nullptr);
  110. if (shader == nullptr)
  111. {
  112. Trace("Shader could not be loaded. This can fail on macOS when dxc or spirv-cross could not be found so the shaders could not be compiled.");
  113. return;
  114. }
  115. // Create the queue
  116. Ref<ComputeQueue> queue = inComputeSystem->CreateComputeQueue();
  117. // Schedule work
  118. for (uint32 iter = 0; iter < cNumIterations; ++iter)
  119. {
  120. // Fill in the context
  121. TestComputeContext *value = context[iter]->Map<TestComputeContext>(ComputeBuffer::EMode::Write);
  122. value->cFloat3Value = cFloat3Value;
  123. value->cUIntValue = cUIntValue;
  124. value->cFloat3Value2 = cFloat3Value2;
  125. value->cUIntValue2 = iter;
  126. value->cNumElements = cNumElements;
  127. context[iter]->Unmap();
  128. queue->SetShader(shader);
  129. queue->SetConstantBuffer("gContext", context[iter]);
  130. context[iter] = nullptr; // Release the reference to ensure the queue keeps ownership
  131. queue->SetBuffer("gOptionalData", optional_buffer);
  132. optional_buffer = nullptr; // Release the reference so we test that the queue keeps ownership and that in the 2nd iteration we can set a null buffer
  133. queue->SetBuffer("gUploadData", upload_buffer);
  134. queue->SetRWBuffer("gData", buffer);
  135. queue->Dispatch((cNumElements + cTestComputeGroupSize - 1) / cTestComputeGroupSize);
  136. }
  137. // Run all queued commands
  138. queue->ScheduleReadback(readback_buffer, buffer);
  139. queue->ExecuteAndWait();
  140. // Calculate the expected result
  141. Array<uint32> expected_data(cNumElements);
  142. for (uint32 iter = 0; iter < cNumIterations; ++iter)
  143. {
  144. // Copy of the shader logic
  145. uint cUIntValue2 = iter;
  146. if (cUIntValue2 == 0)
  147. {
  148. // First write, uses optional data and tests that the packing of float3/uint3's works
  149. for (uint32 i = 0; i < cNumElements; ++i)
  150. expected_data[i] = optional_data[i] + int(cFloat3Value2.y) + cUploadValue;
  151. }
  152. else
  153. {
  154. // Read-modify-write gData
  155. for (uint32 i = 0; i < cNumElements; ++i)
  156. expected_data[i] = (expected_data[i] + cUIntValue) * cUIntValue2;
  157. }
  158. }
  159. // Compare computed data with expected data
  160. uint32 *data = readback_buffer->Map<uint32>(ComputeBuffer::EMode::Read);
  161. for (uint32 i = 0; i < cNumElements; ++i)
  162. CHECK(data[i] == expected_data[i]);
  163. readback_buffer->Unmap();
  164. }
  165. #ifdef JPH_USE_DX12
  166. TEST_CASE("TestComputeDX12")
  167. {
  168. Ref<ComputeSystem> compute_system = CreateComputeSystemDX12();
  169. CHECK(compute_system != nullptr);
  170. if (compute_system != nullptr)
  171. RunTests(compute_system);
  172. }
  173. #endif // JPH_USE_DX12
  174. #ifdef JPH_USE_MTL
  175. TEST_CASE("TestComputeMTL")
  176. {
  177. Ref<ComputeSystem> compute_system = CreateComputeSystemMTL();
  178. CHECK(compute_system != nullptr);
  179. if (compute_system != nullptr)
  180. RunTests(compute_system);
  181. }
  182. #endif // JPH_USE_MTL
  183. #ifdef JPH_USE_VK
  184. TEST_CASE("TestComputeVK")
  185. {
  186. Ref<ComputeSystem> compute_system = CreateComputeSystemVK();
  187. CHECK(compute_system != nullptr);
  188. if (compute_system != nullptr)
  189. RunTests(compute_system);
  190. }
  191. #endif // JPH_USE_VK
  192. }
  193. #endif // defined(JPH_USE_DX12) || defined(JPH_USE_MTL) || defined(JPH_USE_VK)