ComputeSystemMTL.mm 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. // Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
  2. // SPDX-FileCopyrightText: 2025 Jorrit Rouwe
  3. // SPDX-License-Identifier: MIT
  4. #include <Jolt/Jolt.h>
  5. #ifdef JPH_USE_MTL
  6. #include <Jolt/Compute/MTL/ComputeSystemMTL.h>
  7. #include <Jolt/Compute/MTL/ComputeBufferMTL.h>
  8. #include <Jolt/Compute/MTL/ComputeShaderMTL.h>
  9. #include <Jolt/Compute/MTL/ComputeQueueMTL.h>
  10. JPH_NAMESPACE_BEGIN
  // Runtime type information for ComputeSystemMTL, registering ComputeSystem as its base class
  JPH_IMPLEMENT_RTTI_VIRTUAL(ComputeSystemMTL)
  {
      JPH_ADD_BASE_CLASS(ComputeSystemMTL, ComputeSystem)
  }
  // Attach the compute system to a Metal device.
  // The device is retained here (manual retain/release) and released again in Shutdown().
  // Always reports success.
  bool ComputeSystemMTL::Initialize(id<MTLDevice> inDevice)
  {
      [inDevice retain];
      mDevice = inDevice;

      return true;
  }
  // Release the Metal objects owned by this compute system.
  // mDevice was retained in Initialize() and mShaderLibrary was created with a 'new' method in
  // CreateComputeShader(), so both are owned here. The members are cleared after releasing so that:
  // - calling Shutdown() twice does not over-release, and
  // - a later CreateComputeShader() sees a nil library and reloads it instead of using a freed object.
  void ComputeSystemMTL::Shutdown()
  {
      [mShaderLibrary release];
      mShaderLibrary = nil;

      [mDevice release];
      mDevice = nil;
  }
  // Create a compute shader for the kernel function named inName.
  // On first use the shader library "Jolt.metallib" is fetched through mShaderLoader and cached in
  // mShaderLibrary; later calls reuse it. The group sizes are forwarded unchanged to ComputeShaderMTL.
  // Returns a result holding either the new shader or an error message.
  // Memory is managed manually: objects obtained from new*/create calls are owned here and must be
  // released or handed over on every path.
  ComputeShaderResult ComputeSystemMTL::CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ)
  {
      ComputeShaderResult result;

      if (mShaderLibrary == nil)
      {
          // Load the shader library containing all shaders
          Array<uint8> *data = new Array<uint8>();
          String error;
          if (!mShaderLoader("Jolt.metallib", *data, error))
          {
              result.SetError(error);
              delete data;
              return result;
          }

          // Wrap the bytes without copying; the destructor block frees the array once the dispatch data is destroyed
          dispatch_data_t data_dispatch = dispatch_data_create(data->data(), data->size(), nullptr, ^{ delete data; });

          // Create the library
          NSError *ns_error = nil;
          mShaderLibrary = [mDevice newLibraryWithData: data_dispatch error: &ns_error];

          // dispatch_data_create handed us an owned reference. Without releasing it, neither the dispatch data
          // nor the shader bytes it wraps would ever be freed.
          dispatch_release(data_dispatch);

          // Judge success by the returned object rather than the error pointer (Cocoa convention)
          if (mShaderLibrary == nil)
          {
              result.SetError("Failed to load shader library");
              return result;
          }
      }

      // Get the shader function (returned with +1 ownership because of the 'new' prefix)
      id<MTLFunction> function = [mShaderLibrary newFunctionWithName: [NSString stringWithCString: inName encoding: NSUTF8StringEncoding]];
      if (function == nil)
      {
          result.SetError("Failed to instantiate compute shader");
          return result;
      }

      // Create the pipeline, requesting reflection data for argument bindings and buffer types
      NSError *error = nil;
      MTLComputePipelineReflection *reflection = nil;
      id<MTLComputePipelineState> pipeline_state = [mDevice newComputePipelineStateWithFunction: function options: MTLPipelineOptionBindingInfo | MTLPipelineOptionBufferTypeInfo reflection: &reflection error: &error];

      // Our reference to the function is no longer needed; release it on the success path as well as the failure path
      [function release];

      if (pipeline_state == nil)
      {
          result.SetError("Failed to create compute pipeline");
          return result;
      }

      result.Set(new ComputeShaderMTL(pipeline_state, reflection, inGroupSizeX, inGroupSizeY, inGroupSizeZ));
      return result;
  }
  // Create a Metal-backed compute buffer of the given type, size and stride, optionally filled from inData.
  // The buffer is held in a Ref, so if its initialization fails it is destroyed when this function returns.
  // Returns a result holding either the buffer or an error message.
  ComputeBufferResult ComputeSystemMTL::CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData)
  {
      ComputeBufferResult result;

      Ref<ComputeBufferMTL> new_buffer = new ComputeBufferMTL(this, inType, inSize, inStride);
      if (new_buffer->Initialize(inData))
          result.Set(new_buffer.GetPtr());
      else
          result.SetError("Failed to create compute buffer");

      return result;
  }
  // Create a compute queue bound to the device passed to Initialize(). This function never reports an error.
  ComputeQueueResult ComputeSystemMTL::CreateComputeQueue()
  {
      ComputeQueueResult result;

      ComputeQueueMTL *queue = new ComputeQueueMTL(mDevice);
      result.Set(queue);

      return result;
  }
  88. JPH_NAMESPACE_END
  89. #endif // JPH_USE_MTL