// ComputeQueueMTL.mm (3.4 KB)
  1. // Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
  2. // SPDX-FileCopyrightText: 2025 Jorrit Rouwe
  3. // SPDX-License-Identifier: MIT
  4. #include <Jolt/Jolt.h>
  5. #ifdef JPH_USE_MTL
  6. #include <Jolt/Compute/MTL/ComputeQueueMTL.h>
  7. #include <Jolt/Compute/MTL/ComputeShaderMTL.h>
  8. #include <Jolt/Compute/MTL/ComputeBufferMTL.h>
  9. #include <Jolt/Compute/MTL/ComputeSystemMTL.h>
  10. JPH_NAMESPACE_BEGIN
  // Destroys the queue.
  // Work that was submitted through Execute() is waited on first.
  // A command buffer that was recorded but never submitted is discarded without committing it.
  // Its encoder must still be ended, because Metal's validation layer asserts when a command
  // encoder is released without endEncoding.
  ComputeQueueMTL::~ComputeQueueMTL()
  {
      // Wait for any submitted work to complete before tearing down
      Wait();

      // Recorded but never executed: close the encoder and drop the command buffer uncommitted
      if (mCommandBuffer != nil)
      {
          [mComputeEncoder endEncoding];
          mComputeEncoder = nil;
          mCommandBuffer = nil;
      }

      // Balance the owned (+1) reference obtained from newCommandQueue in the constructor
      [mCommandQueue release];
  }
  // Creates the Metal command queue on which all command buffers of this queue are created.
  // newCommandQueue hands back an owned (+1) reference, which the destructor releases.
  ComputeQueueMTL::ComputeQueueMTL(id<MTLDevice> inDevice) :
      mCommandQueue([inDevice newCommandQueue])
  {
  }
  // Lazily starts recording.
  // On the first command after construction or after Wait(), this creates a command buffer
  // and a compute encoder for it. Later calls reuse them until Execute()/Wait() finish the cycle.
  // Recording is not allowed while a submitted command buffer is outstanding; call Wait() first.
  void ComputeQueueMTL::BeginCommandBuffer()
  {
      // After Execute() the encoder has been ended and the buffer committed, so recording into them is invalid
      JPH_ASSERT(!mIsExecuting);

      if (mCommandBuffer != nil)
          return; // Already recording into the current command buffer

      // NOTE(review): commandBuffer / computeCommandEncoder return objects the caller does not own.
      // This file uses manual retain/release and does not retain them. It therefore relies on the
      // surrounding autorelease pool outliving the Execute()/Wait() cycle — confirm with callers.
      mCommandBuffer = [mCommandQueue commandBuffer];
      mComputeEncoder = [mCommandBuffer computeCommandEncoder];
  }
  // Makes inShader the active compute pipeline, starting a command buffer if none is being recorded.
  // The shader is remembered because buffer bindings and dispatches query it for binding
  // indices and thread group size.
  void ComputeQueueMTL::SetShader(const ComputeShader *inShader)
  {
      BeginCommandBuffer();

      const ComputeShaderMTL *shader_mtl = static_cast<const ComputeShaderMTL *>(inShader);
      mShader = shader_mtl;

      [mComputeEncoder setComputePipelineState: shader_mtl->GetPipelineState()];
  }
  // Binds a constant buffer to the slot the active shader assigns to inName.
  // A null buffer is ignored.
  // SetShader() must have been called since the last Execute(): Execute() clears the shader,
  // and the slot index is looked up on it.
  void ComputeQueueMTL::SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer)
  {
      if (inBuffer == nullptr)
          return;

      JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::ConstantBuffer);
      JPH_ASSERT(mShader != nullptr); // The binding index lookup below needs an active shader

      BeginCommandBuffer();

      const ComputeBufferMTL *buffer = static_cast<const ComputeBufferMTL *>(inBuffer);
      [mComputeEncoder setBuffer: buffer->GetBuffer() offset: 0 atIndex: mShader->NameToBindingIndex(inName)];
  }
  // Binds a read-only buffer to the slot the active shader assigns to inName.
  // Accepts upload, regular and read-write buffers. A null buffer is ignored.
  // SetShader() must have been called since the last Execute(): Execute() clears the shader,
  // and the slot index is looked up on it.
  void ComputeQueueMTL::SetBuffer(const char *inName, const ComputeBuffer *inBuffer)
  {
      if (inBuffer == nullptr)
          return;

      JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::UploadBuffer || inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
      JPH_ASSERT(mShader != nullptr); // The binding index lookup below needs an active shader

      BeginCommandBuffer();

      const ComputeBufferMTL *buffer = static_cast<const ComputeBufferMTL *>(inBuffer);
      [mComputeEncoder setBuffer: buffer->GetBuffer() offset: 0 atIndex: mShader->NameToBindingIndex(inName)];
  }
  // Binds a read-write buffer to the slot the active shader assigns to inName.
  // A null buffer is ignored.
  // SetShader() must have been called since the last Execute(): Execute() clears the shader,
  // and the slot index is looked up on it.
  void ComputeQueueMTL::SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier)
  {
      // NOTE: inBarrier is intentionally ignored by this backend.
      // The encoder is created with the plain computeCommandEncoder call, which (per Apple's docs)
      // uses the serial dispatch type, so dispatches are not expected to overlap. Confirm if the
      // encoder creation ever changes.
      (void)inBarrier;

      if (inBuffer == nullptr)
          return;

      JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
      JPH_ASSERT(mShader != nullptr); // The binding index lookup below needs an active shader

      BeginCommandBuffer();

      const ComputeBufferMTL *buffer = static_cast<const ComputeBufferMTL *>(inBuffer);
      [mComputeEncoder setBuffer: buffer->GetBuffer() offset: 0 atIndex: mShader->NameToBindingIndex(inName)];
  }
  // No-op readback.
  // ComputeBuffer::CreateReadBackBuffer returns the source buffer itself for this backend,
  // so no copy needs to be encoded. This only checks that the caller passed that same buffer.
  void ComputeQueueMTL::ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc)
  {
      JPH_ASSERT(inSrc == inDst);
  }
  // Records a dispatch of the active shader over a grid of
  // inThreadGroupsX x inThreadGroupsY x inThreadGroupsZ thread groups.
  // The number of threads per group is taken from the shader.
  // A grid with zero groups in any dimension has no work and is skipped.
  void ComputeQueueMTL::Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ)
  {
      JPH_ASSERT(mShader != nullptr); // SetShader() must be called before dispatching

      // Empty grid: nothing to run. Do not hand a zero-sized grid to Metal, whose API validation may flag it.
      if (inThreadGroupsX == 0 || inThreadGroupsY == 0 || inThreadGroupsZ == 0)
          return;

      BeginCommandBuffer();

      const MTLSize threadgroups_per_grid = MTLSizeMake(inThreadGroupsX, inThreadGroupsY, inThreadGroupsZ);
      const MTLSize threads_per_threadgroup = MTLSizeMake(mShader->GetGroupSizeX(), mShader->GetGroupSizeY(), mShader->GetGroupSizeZ());
      [mComputeEncoder dispatchThreadgroups: threadgroups_per_grid threadsPerThreadgroup: threads_per_threadgroup];
  }
  // Submits the recorded commands to the GPU without waiting for them to finish.
  // Does nothing if no commands were recorded.
  // Wait() must be called before recording new commands or executing again.
  void ComputeQueueMTL::Execute()
  {
      // Executing twice without Wait() would end an already ended encoder and commit an already committed buffer
      JPH_ASSERT(!mIsExecuting);
      if (mCommandBuffer == nil || mIsExecuting)
          return;

      // Close the encoder and hand the command buffer to the GPU
      [mComputeEncoder endEncoding];
      [mCommandBuffer commit];

      // The shader binding does not carry over; the next recording must call SetShader() again
      mShader = nullptr;
      mIsExecuting = true;
  }
  // Blocks until the command buffer submitted by Execute() has completed.
  // Afterwards the queue is reset so the next recorded command starts a fresh command buffer.
  // Returns immediately when nothing has been submitted.
  void ComputeQueueMTL::Wait()
  {
      if (mIsExecuting)
      {
          [mCommandBuffer waitUntilCompleted];

          // Forget the finished buffer and its encoder
          mComputeEncoder = nil;
          mCommandBuffer = nil;
          mIsExecuting = false;
      }
  }
  93. JPH_NAMESPACE_END
  94. #endif // JPH_USE_MTL