ComputeShaderDX12.h 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. // Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
  2. // SPDX-FileCopyrightText: 2025 Jorrit Rouwe
  3. // SPDX-License-Identifier: MIT
  4. #pragma once
  5. #ifdef JPH_USE_DX12
  6. #include <Jolt/Compute/ComputeShader.h>
  7. #include <Jolt/Compute/DX12/IncludeDX12.h>
  8. #include <Jolt/Core/UnorderedMap.h>
  9. JPH_NAMESPACE_BEGIN
  10. /// Compute shader handle for DirectX
  11. class JPH_EXPORT ComputeShaderDX12 : public ComputeShader
  12. {
  13. public:
  14. JPH_OVERRIDE_NEW_DELETE
  15. /// Constructor
  16. ComputeShaderDX12(ComPtr<ID3DBlob> inShader, ComPtr<ID3D12RootSignature> inRootSignature, ComPtr<ID3D12PipelineState> inPipelineState, Array<String> &&inBindingNames, UnorderedMap<string_view, uint> &&inNameToIndex, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) :
  17. ComputeShader(inGroupSizeX, inGroupSizeY, inGroupSizeZ),
  18. mShader(inShader),
  19. mRootSignature(inRootSignature),
  20. mPipelineState(inPipelineState),
  21. mBindingNames(std::move(inBindingNames)),
  22. mNameToIndex(std::move(inNameToIndex))
  23. {
  24. }
  25. /// Get index of shader parameter
  26. uint NameToIndex(const char *inName) const
  27. {
  28. UnorderedMap<string_view, uint>::const_iterator it = mNameToIndex.find(inName);
  29. JPH_ASSERT(it != mNameToIndex.end());
  30. return it->second;
  31. }
  32. /// Getters
  33. ID3D12PipelineState * GetPipelineState() const { return mPipelineState.Get(); }
  34. ID3D12RootSignature * GetRootSignature() const { return mRootSignature.Get(); }
  35. private:
  36. ComPtr<ID3DBlob> mShader; ///< The compiled shader
  37. ComPtr<ID3D12RootSignature> mRootSignature; ///< The root signature for this shader
  38. ComPtr<ID3D12PipelineState> mPipelineState; ///< The pipeline state object for this shader
  39. Array<String> mBindingNames; ///< A list of binding names, mNameToIndex points to these strings
  40. UnorderedMap<string_view, uint> mNameToIndex; ///< Maps names to indices for the shader parameters, using a string_view so we can do find() without an allocation
  41. };
  42. JPH_NAMESPACE_END
  43. #endif // JPH_USE_DX12