// D3DShaderProgram.cpp (6.2 KB)

  1. // Copyright (C) 2009-present, Panagiotis Christopoulos Charitos and contributors.
  2. // All rights reserved.
  3. // Code licensed under the BSD License.
  4. // http://www.anki3d.org/LICENSE
  5. #include <AnKi/Gr/D3D/D3DShaderProgram.h>
  6. #include <AnKi/Gr/D3D/D3DShader.h>
  7. #include <AnKi/Gr/BackendCommon/Functions.h>
  8. #include <AnKi/Gr/D3D/D3DDescriptor.h>
  9. #include <AnKi/Gr/D3D/D3DGraphicsState.h>
  10. namespace anki {
  11. ShaderProgram* ShaderProgram::newInstance(const ShaderProgramInitInfo& init)
  12. {
  13. ShaderProgramImpl* impl = anki::newInstance<ShaderProgramImpl>(GrMemoryPool::getSingleton(), init.getName());
  14. const Error err = impl->init(init);
  15. if(err)
  16. {
  17. deleteInstance(GrMemoryPool::getSingleton(), impl);
  18. impl = nullptr;
  19. }
  20. return impl;
  21. }
  22. ConstWeakArray<U8> ShaderProgram::getShaderGroupHandles() const
  23. {
  24. ANKI_ASSERT(!"TODO");
  25. return ConstWeakArray<U8>();
  26. }
  27. Buffer& ShaderProgram::getShaderGroupHandlesGpuBuffer() const
  28. {
  29. ANKI_ASSERT(!"TODO");
  30. void* ptr = nullptr;
  31. return *reinterpret_cast<Buffer*>(ptr);
  32. }
  33. ShaderProgramImpl::~ShaderProgramImpl()
  34. {
  35. safeRelease(m_compute.m_pipelineState);
  36. safeRelease(m_workGraph.m_stateObject);
  37. deleteInstance(GrMemoryPool::getSingleton(), m_graphics.m_pipelineFactory);
  38. }
  39. Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
  40. {
  41. ANKI_ASSERT(inf.isValid());
  42. // Create the shader references
  43. if(inf.m_computeShader)
  44. {
  45. m_shaders.emplaceBack(inf.m_computeShader);
  46. }
  47. else if(inf.m_graphicsShaders[ShaderType::kPixel])
  48. {
  49. for(Shader* s : inf.m_graphicsShaders)
  50. {
  51. if(s)
  52. {
  53. m_shaders.emplaceBack(s);
  54. }
  55. }
  56. }
  57. else if(inf.m_workGraph.m_shader)
  58. {
  59. m_shaders.emplaceBack(inf.m_workGraph.m_shader);
  60. }
  61. else
  62. {
  63. ANKI_ASSERT(!"TODO");
  64. }
  65. ANKI_ASSERT(m_shaders.getSize() > 0);
  66. for(ShaderPtr& shader : m_shaders)
  67. {
  68. m_shaderTypes |= ShaderTypeBit(1 << shader->getShaderType());
  69. }
  70. const Bool isGraphicsProg = !!(m_shaderTypes & ShaderTypeBit::kAllGraphics);
  71. const Bool isComputeProg = !!(m_shaderTypes & ShaderTypeBit::kCompute);
  72. const Bool isRtProg = !!(m_shaderTypes & ShaderTypeBit::kAllRayTracing);
  73. const Bool isWorkGraph = !!(m_shaderTypes & ShaderTypeBit::kWorkGraph);
  74. // Link reflection
  75. ShaderReflection refl;
  76. Bool firstLink = true;
  77. for(ShaderPtr& shader : m_shaders)
  78. {
  79. const ShaderImpl& simpl = static_cast<const ShaderImpl&>(*shader);
  80. if(firstLink)
  81. {
  82. refl = simpl.m_reflection;
  83. firstLink = false;
  84. }
  85. else
  86. {
  87. ANKI_CHECK(ShaderReflection::linkShaderReflection(refl, simpl.m_reflection, refl));
  88. }
  89. refl.validate();
  90. }
  91. m_refl = refl;
  92. // Create root signature
  93. ANKI_CHECK(RootSignatureFactory::getSingleton().getOrCreateRootSignature(refl, m_rootSignature));
  94. // Init the create infos
  95. if(isGraphicsProg)
  96. {
  97. for(U32 ishader = 0; ishader < m_shaders.getSize(); ++ishader)
  98. {
  99. const ShaderImpl& shaderImpl = static_cast<const ShaderImpl&>(*m_shaders[ishader]);
  100. m_graphics.m_shaderCreateInfos[shaderImpl.getShaderType()] = {.pShaderBytecode = shaderImpl.m_binary.getBegin(),
  101. .BytecodeLength = shaderImpl.m_binary.getSizeInBytes()};
  102. }
  103. }
  104. // Create the pipeline if compute
  105. if(isComputeProg)
  106. {
  107. const ShaderImpl& shaderImpl = static_cast<const ShaderImpl&>(*m_shaders[0]);
  108. D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {};
  109. desc.pRootSignature = &m_rootSignature->getD3DRootSignature();
  110. desc.CS.BytecodeLength = shaderImpl.m_binary.getSizeInBytes();
  111. desc.CS.pShaderBytecode = shaderImpl.m_binary.getBegin();
  112. ANKI_D3D_CHECK(getDevice().CreateComputePipelineState(&desc, IID_PPV_ARGS(&m_compute.m_pipelineState)));
  113. }
  114. // Create the shader object if workgraph
  115. if(isWorkGraph)
  116. {
  117. const WChar* wgName = L"main";
  118. const ShaderImpl& shaderImpl = static_cast<const ShaderImpl&>(*m_shaders[0]);
  119. // Init sub-objects
  120. CD3DX12_STATE_OBJECT_DESC stateObj(D3D12_STATE_OBJECT_TYPE_EXECUTABLE);
  121. auto lib = stateObj.CreateSubobject<CD3DX12_DXIL_LIBRARY_SUBOBJECT>();
  122. CD3DX12_SHADER_BYTECODE libCode(shaderImpl.m_binary.getBegin(), shaderImpl.m_binary.getSizeInBytes());
  123. lib->SetDXILLibrary(&libCode);
  124. auto rootSigSubObj = stateObj.CreateSubobject<CD3DX12_GLOBAL_ROOT_SIGNATURE_SUBOBJECT>();
  125. rootSigSubObj->SetRootSignature(&m_rootSignature->getD3DRootSignature());
  126. auto wgSubObj = stateObj.CreateSubobject<CD3DX12_WORK_GRAPH_SUBOBJECT>();
  127. wgSubObj->IncludeAllAvailableNodes(); // Auto populate the graph
  128. wgSubObj->SetProgramName(wgName);
  129. GrDynamicArray<Array<WChar, 128>> nodeNames;
  130. nodeNames.resize(inf.m_workGraph.m_nodeSpecializations.getSize());
  131. for(U32 i = 0; i < inf.m_workGraph.m_nodeSpecializations.getSize(); ++i)
  132. {
  133. const WorkGraphNodeSpecialization& specialization = inf.m_workGraph.m_nodeSpecializations[i];
  134. specialization.m_nodeName.toWideChars(nodeNames[i].getBegin(), nodeNames[i].getSize());
  135. CD3DX12_BROADCASTING_LAUNCH_NODE_OVERRIDES* spec = wgSubObj->CreateBroadcastingLaunchNodeOverrides(nodeNames[i].getBegin());
  136. ANKI_ASSERT(specialization.m_maxNodeDispatchGrid > UVec3(0u));
  137. spec->MaxDispatchGrid(specialization.m_maxNodeDispatchGrid.x(), specialization.m_maxNodeDispatchGrid.y(),
  138. specialization.m_maxNodeDispatchGrid.z());
  139. }
  140. // Create state obj
  141. ANKI_D3D_CHECK(getDevice().CreateStateObject(stateObj, IID_PPV_ARGS(&m_workGraph.m_stateObject)));
  142. // Create misc
  143. ComPtr<ID3D12StateObjectProperties1> spSOProps;
  144. ANKI_D3D_CHECK(m_workGraph.m_stateObject->QueryInterface(IID_PPV_ARGS(&spSOProps)));
  145. m_workGraph.m_progIdentifier = spSOProps->GetProgramIdentifier(wgName);
  146. ComPtr<ID3D12WorkGraphProperties> spWGProps;
  147. ANKI_D3D_CHECK(m_workGraph.m_stateObject->QueryInterface(IID_PPV_ARGS(&spWGProps)));
  148. const UINT wgIndex = spWGProps->GetWorkGraphIndex(wgName);
  149. D3D12_WORK_GRAPH_MEMORY_REQUIREMENTS memReqs;
  150. spWGProps->GetWorkGraphMemoryRequirements(wgIndex, &memReqs);
  151. ANKI_ASSERT(spWGProps->GetNumEntrypoints(wgIndex) == 1);
  152. m_workGraphScratchBufferSize = memReqs.MaxSizeInBytes;
  153. }
  154. // Get shader sizes and a few other things
  155. for(const ShaderPtr& s : m_shaders)
  156. {
  157. if(!s.isCreated())
  158. {
  159. continue;
  160. }
  161. const ShaderType type = s->getShaderType();
  162. const U32 size = s->getShaderBinarySize();
  163. m_shaderBinarySizes[type] = size;
  164. }
  165. // Misc
  166. if(isGraphicsProg)
  167. {
  168. m_graphics.m_pipelineFactory = anki::newInstance<GraphicsPipelineFactory>(GrMemoryPool::getSingleton());
  169. }
  170. return Error::kNone;
  171. }
  172. } // end namespace anki