ShaderPlatformInterface.cpp 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <RHI.Builders/ShaderPlatformInterface.h>
  9. #include <Atom/RHI.Edit/Utils.h>
  10. #include <Atom/RHI.Reflect/DX12/PipelineLayoutDescriptor.h>
  11. #include <Atom/RHI.Reflect/DX12/ShaderStageFunction.h>
  12. #include <Atom/RHI/RHIUtils.h>
  13. #include <AzCore/IO/FileIO.h>
  14. #include <AzCore/IO/SystemFile.h>
  15. #include <AzCore/Serialization/Json/JsonUtils.h>
  16. #include <AzFramework/StringFunc/StringFunc.h>
  17. namespace AZ
  18. {
  19. namespace DX12
  20. {
  // File-local string constants. Declared constexpr so the pointers themselves are immutable
  // (a plain `static const char*` could be reassigned by accident).

  // API name for DX12; used for both GetAPIName() (via m_apiName) and GetAPIType().
  static constexpr const char* DX12ApiName = "dx12";
  // Window/category name passed to AZ_Error / AZ_Warning throughout this file.
  static constexpr const char* DX12ShaderPlatformName = "DX12ShaderPlatform";
  // Header prepended to the HLSL source before it is handed to dxc.
  static constexpr const char* PlatformShaderHeader = "Builders/ShaderHeaders/Platform/Windows/DX12/PlatformHeader.hlsli";
  // Header returned by GetAzslHeader() for AZSL compilation.
  static constexpr const char* AzslShaderHeader = "Builders/ShaderHeaders/Platform/Windows/DX12/AzslcHeader.azsli";
  25. ShaderPlatformInterface::ShaderPlatformInterface(uint32_t apiUniqueIndex)
  26. : RHI::ShaderPlatformInterface(apiUniqueIndex), m_apiName{ DX12ApiName }
  27. {
  28. }
  29. RHI::APIType ShaderPlatformInterface::GetAPIType() const
  30. {
  31. return RHI::APIType{ DX12ApiName };
  32. }
  33. AZ::Name ShaderPlatformInterface::GetAPIName() const
  34. {
  35. return m_apiName;
  36. }
  37. RHI::Ptr<RHI::ShaderStageFunction> ShaderPlatformInterface::CreateShaderStageFunction(const StageDescriptor& stageDescriptor)
  38. {
  39. RHI::Ptr<ShaderStageFunction> newShaderStageFunction = ShaderStageFunction::Create(RHI::ToRHIShaderStage(stageDescriptor.m_stageType));
  40. const auto& byteCode = stageDescriptor.m_byteCode;
  41. const int byteCodeIndex = 0;
  42. newShaderStageFunction->SetByteCode(byteCodeIndex, byteCode);
  43. // Read the json data with the specialization constants offsets.
  44. // If the shader was not compiled with specialization constants this attribute will be empty.
  45. AZStd::string fileName;
  46. if (!stageDescriptor.m_extraData.empty())
  47. {
  48. auto jsonOutcome = JsonSerializationUtils::ReadJsonFile(stageDescriptor.m_extraData);
  49. if (!jsonOutcome.IsSuccess())
  50. {
  51. AZ_Error(DX12ShaderPlatformName, false, "%s", jsonOutcome.GetError().c_str());
  52. return nullptr;
  53. }
  54. const rapidjson::Document& doc = jsonOutcome.GetValue();
  55. ShaderStageFunction::SpecializationOffsets offsets;
  56. for (auto itr = doc.MemberBegin(); itr != doc.MemberEnd(); ++itr)
  57. {
  58. if (!AZ::StringFunc::LooksLikeInt(itr->name.GetString()))
  59. {
  60. AZ_Error(DX12ShaderPlatformName, false, "SpecializationId %s is not an Int", itr->name.GetString());
  61. continue;
  62. }
  63. uint32_t specializationId = static_cast<uint32_t>(AZ::StringFunc::ToInt(itr->name.GetString()));
  64. uint32_t offset = itr->value.GetUint();
  65. offsets[specializationId] = offset;
  66. }
  67. newShaderStageFunction->SetSpecializationOffsets(byteCodeIndex, offsets);
  68. }
  69. newShaderStageFunction->Finalize();
  70. return newShaderStageFunction;
  71. }
  72. bool ShaderPlatformInterface::IsShaderStageForRaster(RHI::ShaderHardwareStage shaderStageType) const
  73. {
  74. bool hasRasterProgram = false;
  75. hasRasterProgram |= shaderStageType == RHI::ShaderHardwareStage::Vertex;
  76. hasRasterProgram |= shaderStageType == RHI::ShaderHardwareStage::Fragment;
  77. hasRasterProgram |= shaderStageType == RHI::ShaderHardwareStage::Geometry;
  78. return hasRasterProgram;
  79. }
  80. bool ShaderPlatformInterface::IsShaderStageForCompute(RHI::ShaderHardwareStage shaderStageType) const
  81. {
  82. return (shaderStageType == RHI::ShaderHardwareStage::Compute);
  83. }
  84. bool ShaderPlatformInterface::IsShaderStageForRayTracing(RHI::ShaderHardwareStage shaderStageType) const
  85. {
  86. return (shaderStageType == RHI::ShaderHardwareStage::RayTracing);
  87. }
  88. RHI::Ptr<RHI::PipelineLayoutDescriptor> ShaderPlatformInterface::CreatePipelineLayoutDescriptor()
  89. {
  90. return PipelineLayoutDescriptor::Create();
  91. }
  92. bool ShaderPlatformInterface::BuildPipelineLayoutDescriptor(
  93. RHI::Ptr<RHI::PipelineLayoutDescriptor> pipelineLayoutDescriptorBase,
  94. const ShaderResourceGroupInfoList& srgInfoList,
  95. const RootConstantsInfo& rootConstantsInfo,
  96. const RHI::ShaderBuildArguments& shaderBuildArguments)
  97. {
  98. PipelineLayoutDescriptor* pipelineLayoutDescriptor = azrtti_cast<PipelineLayoutDescriptor*>(pipelineLayoutDescriptorBase.get());
  99. AZ_Assert(pipelineLayoutDescriptor, "PipelineLayoutDescriptor should have been created by now");
  100. for (const ShaderResourceGroupInfo& srgInfo : srgInfoList)
  101. {
  102. ShaderResourceGroupVisibility srgVisibility;
  103. // Copy the resources binding info so we can erase the static samplers
  104. // while adding them to the m_staticSamplersShaderStageMask list.
  105. // Each static sampler has it's own visibility. All other resources share the same visibility mask.
  106. auto resourcesBindingInfo = srgInfo.m_bindingInfo.m_resourcesRegisterMap;
  107. for (const RHI::ShaderInputStaticSamplerDescriptor& staticSamplerDescriptor : srgInfo.m_layout->GetStaticSamplers())
  108. {
  109. auto findIt = resourcesBindingInfo.find(staticSamplerDescriptor.m_name);
  110. if (findIt != resourcesBindingInfo.end())
  111. {
  112. // Erase the static sampler from the resource list so we don't use it when calculating
  113. // the descriptor table shader stage mask.
  114. resourcesBindingInfo.erase(findIt);
  115. }
  116. else
  117. {
  118. AZ_Error(DX12ShaderPlatformName, false, "Could not find binding info for static sampler '%s'", staticSamplerDescriptor.m_name.GetCStr());
  119. return false;
  120. }
  121. }
  122. const bool dxcDisableOptimizations = RHI::ShaderBuildArguments::HasArgument(shaderBuildArguments.m_dxcArguments, "-Od");
  123. if (dxcDisableOptimizations)
  124. {
  125. // When optimizations are disabled (-Od), all resources declared in the source file are available to all stages
  126. // (when enabled only the resources which are referenced in a stage are bound to the stage)
  127. srgVisibility.m_descriptorTableShaderStageMask = RHI::ShaderStageMask::All;
  128. }
  129. else
  130. {
  131. for (const auto& bindInfo : resourcesBindingInfo)
  132. {
  133. srgVisibility.m_descriptorTableShaderStageMask |= bindInfo.second.m_shaderStageMask;
  134. }
  135. srgVisibility.m_descriptorTableShaderStageMask |= srgInfo.m_bindingInfo.m_constantDataBindingInfo.m_shaderStageMask;
  136. }
  137. pipelineLayoutDescriptor->AddShaderResourceGroupVisibility(srgVisibility);
  138. if (rootConstantsInfo.m_totalSizeInBytes > 0)
  139. {
  140. AZ_Assert((rootConstantsInfo.m_totalSizeInBytes % 4) == 0, "Inline constant size is not a multiple of 32 bit");
  141. pipelineLayoutDescriptor->SetRootConstantBinding(RootConstantBinding{ rootConstantsInfo.m_totalSizeInBytes / 4, rootConstantsInfo.m_registerId, rootConstantsInfo.m_spaceId });
  142. }
  143. }
  144. return pipelineLayoutDescriptor->Finalize() == RHI::ResultCode::Success;
  145. }
  146. bool ShaderPlatformInterface::CompilePlatformInternal(
  147. [[maybe_unused]] const AssetBuilderSDK::PlatformInfo& platform,
  148. const AZStd::string& shaderSourcePath,
  149. const AZStd::string& functionName,
  150. RHI::ShaderHardwareStage shaderStage,
  151. const AZStd::string& tempFolderPath,
  152. StageDescriptor& outputDescriptor,
  153. const RHI::ShaderBuildArguments& shaderBuildArguments,
  154. const bool useSpecializationConstants) const
  155. {
  156. AZStd::vector<uint8_t> shaderByteCode;
  157. AZStd::string specializationOffsetsFile;
  158. // Compile HLSL shader to byte code
  159. bool compiledSucessfully = CompileHLSLShader(
  160. shaderSourcePath, // shader source filepath
  161. tempFolderPath, // AP job temp folder
  162. functionName, // name of function that is the entry point
  163. shaderStage, // shader stage (vertex shader, pixel shader, ...)
  164. shaderBuildArguments,
  165. shaderByteCode, // compiled shader output
  166. outputDescriptor.m_byProducts, // dynamic branch count output & byproduct files
  167. specializationOffsetsFile, // path to the json file with the specialization offsets
  168. useSpecializationConstants); // if the shader stage it's using specialization constants
  169. if (!compiledSucessfully)
  170. {
  171. AZ_Error(DX12ShaderPlatformName, false, "Failed to compile HLSL shader");
  172. return false;
  173. }
  174. const char byteCodeHeader[] = { 'D', 'X', 'B', 'C' };
  175. if (shaderByteCode.size() > sizeof(byteCodeHeader) && memcmp(shaderByteCode.data(), byteCodeHeader, sizeof(byteCodeHeader)) == 0)
  176. {
  177. outputDescriptor.m_stageType = shaderStage;
  178. outputDescriptor.m_byteCode = AZStd::move(shaderByteCode);
  179. outputDescriptor.m_extraData = AZStd::move(specializationOffsetsFile);
  180. }
  181. else
  182. {
  183. AZ_Error(DX12ShaderPlatformName, false, "Compiled shader for %s is invalid", shaderSourcePath.c_str());
  184. return false;
  185. }
  186. return true;
  187. }
  188. const char* ShaderPlatformInterface::GetAzslHeader(const AssetBuilderSDK::PlatformInfo& platform) const
  189. {
  190. AZ_UNUSED(platform);
  191. return AzslShaderHeader;
  192. }
  193. bool ShaderPlatformInterface::CompileHLSLShader(
  194. const AZStd::string& shaderSourceFile,
  195. const AZStd::string& tempFolder,
  196. const AZStd::string& entryPoint,
  197. const RHI::ShaderHardwareStage shaderStageType,
  198. const RHI::ShaderBuildArguments& shaderBuildArguments,
  199. AZStd::vector<uint8_t>& compiledShader,
  200. ByProducts& byProducts,
  201. AZStd::string& specializationOffsetsFile,
  202. const bool useSpecializationConstants) const
  203. {
  204. // Shader compiler executable
  205. const auto dxcRelativePath = RHI::GetDirectXShaderCompilerPath("Builders/DirectXShaderCompiler/dxc.exe");
  206. // NOTE:
  207. // Running DX12 on PC with DXIL shaders requires modern GPUs and at least Windows 10 Build 1803 or later for Shader Model 6.2
  208. // https://github.com/Microsoft/DirectXShaderCompiler/wiki/Running-Shaders
  209. // -Fo "Output object file"
  210. AZStd::string shaderOutputFile;
  211. AzFramework::StringFunc::Path::GetFileName(shaderSourceFile.c_str(), shaderOutputFile);
  212. AzFramework::StringFunc::Path::Join(tempFolder.c_str(), shaderOutputFile.c_str(), shaderOutputFile);
  213. AzFramework::StringFunc::Path::ReplaceExtension(shaderOutputFile, "dxil.bin");
  214. // -Fh "Output header file containing object code", used for counting dynamic branches
  215. AZStd::string objectCodeOutputFile;
  216. AzFramework::StringFunc::Path::GetFileName(shaderSourceFile.c_str(), objectCodeOutputFile);
  217. AzFramework::StringFunc::Path::Join(tempFolder.c_str(), objectCodeOutputFile.c_str(), objectCodeOutputFile);
  218. AzFramework::StringFunc::Path::ReplaceExtension(objectCodeOutputFile, "dxil.txt");
  219. // Stage profile name parameter
  220. // Note: RayTracing shaders must be compiled with version 6_3, while the rest of the stages
  221. // are compiled with version 6_2, so RayTracing cannot share the version constant.
  222. const AZStd::string shaderModelVersion = "6_2";
  223. const AZStd::unordered_map<RHI::ShaderHardwareStage, AZStd::string> stageToProfileName =
  224. {
  225. {RHI::ShaderHardwareStage::Vertex, "vs_" + shaderModelVersion},
  226. {RHI::ShaderHardwareStage::Fragment, "ps_" + shaderModelVersion},
  227. {RHI::ShaderHardwareStage::Compute, "cs_" + shaderModelVersion},
  228. {RHI::ShaderHardwareStage::Geometry, "gs_" + shaderModelVersion},
  229. {RHI::ShaderHardwareStage::RayTracing, "lib_6_3"}
  230. };
  231. auto profileIt = stageToProfileName.find(shaderStageType);
  232. if (profileIt == stageToProfileName.end())
  233. {
  234. AZ_Error(DX12ShaderPlatformName, false, "Unsupported shader stage");
  235. return false;
  236. }
  237. const bool graphicsDevMode = RHI::IsGraphicsDevModeEnabled();
  238. // Compilation parameters
  239. auto dxcArguments = shaderBuildArguments.m_dxcArguments;
  240. if (graphicsDevMode || BuildHasDebugInfo(shaderBuildArguments))
  241. {
  242. RHI::ShaderBuildArguments::AppendArguments(dxcArguments, { "-Zi", "-Zss", "-Od" });
  243. }
  244. unsigned char sha1[RHI::Sha1NumBytes];
  245. RHI::PrependArguments args;
  246. args.m_sourceFile = shaderSourceFile.c_str();
  247. args.m_prependFile = PlatformShaderHeader;
  248. args.m_destinationFolder = tempFolder.c_str();
  249. args.m_digest = &sha1;
  250. const auto dxcInputFile = RHI::PrependFile(args); // Prepend PAL header & obtain hash
  251. // -Fd "Write debug information to the given file, or automatically named file in directory when ending in '\\'"
  252. // If we use the auto-name (hash), there is no way we can retrieve that name apart from listing the directory.
  253. // Instead, let's just generate that hash ourselves.
  254. AZStd::string symbolDatabaseFileCliArgument{" "}; // when not debug: still insert a space between 5.dxil and 7.hlsl-in
  255. if (graphicsDevMode || shaderBuildArguments.m_generateDebugInfo)
  256. {
  257. // prepare .pdb filename:
  258. AZStd::string sha1hex = RHI::ByteToHexString(sha1);
  259. AZStd::string symbolDatabaseFilePath = dxcInputFile.c_str(); // mutate from source
  260. AZStd::string pdbFileName = sha1hex + "-" + profileIt->second; // concatenate the shader profile to disambiguate vs/ps...
  261. AzFramework::StringFunc::Path::ReplaceFullName(symbolDatabaseFilePath, pdbFileName.c_str(), "pdb");
  262. // it is possible that another activated platform/profile, already exported that file. (since it's hashed on the source file)
  263. // dxc returns an error in such case. we get less surprising effets by just not mentionning an -Fd argument
  264. if (AZ::IO::SystemFile::Exists(symbolDatabaseFilePath.c_str()))
  265. {
  266. AZ_Warning(DX12ShaderPlatformName, false, "debug symbol file %s already exists -> -Fd argument dropped", symbolDatabaseFilePath.c_str());
  267. }
  268. else
  269. {
  270. symbolDatabaseFileCliArgument = " -Fd \"" + symbolDatabaseFilePath + "\" "; // 6.pdb hereunder
  271. byProducts.m_intermediatePaths.emplace(AZStd::move(symbolDatabaseFilePath));
  272. }
  273. }
  274. const auto params = RHI::ShaderBuildArguments::ListAsString(dxcArguments);
  275. const auto dxcEntryPoint = (shaderStageType == RHI::ShaderHardwareStage::RayTracing) ? "" : AZStd::string::format("-E %s", entryPoint.c_str());
  276. // 1.entry 3.config 5.dxil 7.hlsl-in
  277. // | 2.SM | 4.output | 6.pdb |
  278. // | | | | | | |
  279. const auto dxcCommandOptions = AZStd::string::format("%s -T %s %s -Fo \"%s\" -Fh \"%s\"%s\"%s\"",
  280. dxcEntryPoint.c_str(), // 1
  281. profileIt->second.c_str(), // 2
  282. params.c_str(), // 3
  283. shaderOutputFile.c_str(), // 4
  284. objectCodeOutputFile.c_str(), // 5
  285. symbolDatabaseFileCliArgument.c_str(), // 6
  286. dxcInputFile.c_str() // 7
  287. );
  288. // Run Shader Compiler
  289. if (!RHI::ExecuteShaderCompiler(dxcRelativePath, dxcCommandOptions, shaderSourceFile, tempFolder, "DXC"))
  290. {
  291. return false;
  292. }
  293. if (useSpecializationConstants)
  294. {
  295. // Need to patch the shader so it can be used with specialization constants.
  296. const auto dxscRelativePath = RHI::GetDirectXShaderCompilerPath("Builders/DirectXShaderCompiler/dxsc.exe");
  297. AZStd::string shaderOutputCommon;
  298. AzFramework::StringFunc::Path::GetFileName(shaderSourceFile.c_str(), shaderOutputCommon);
  299. AzFramework::StringFunc::Path::Join(tempFolder.c_str(), shaderOutputCommon.c_str(), shaderOutputCommon);
  300. AZStd::string patchedShaderOutput = shaderOutputCommon;
  301. AzFramework::StringFunc::Path::ReplaceExtension(patchedShaderOutput, "dxil.patched.bin");
  302. AZStd::string offsetsOutput = shaderOutputCommon;
  303. AzFramework::StringFunc::Path::ReplaceExtension(offsetsOutput, "offsets.json");
  304. const auto dxscCommandOptions = AZStd::string::format(
  305. // 1.sentinel 3.offsets_output
  306. // | 2.output | 4.dxil-in
  307. // | | | |
  308. "-sv=%lu -o=\"%s\" -f=\"%s\" \"%s\"",
  309. static_cast<unsigned long>(SCSentinelValue), // 1
  310. patchedShaderOutput.c_str(), // 2
  311. offsetsOutput.c_str(), // 3
  312. shaderOutputFile.c_str() // 4
  313. );
  314. if (!RHI::ExecuteShaderCompiler(dxscRelativePath, dxscCommandOptions, shaderSourceFile, tempFolder, "DXSC"))
  315. {
  316. return false;
  317. }
  318. shaderOutputFile = patchedShaderOutput;
  319. specializationOffsetsFile = offsetsOutput;
  320. }
  321. auto shaderOutputFileLoadResult = AZ::RHI::LoadFileBytes(shaderOutputFile.c_str());
  322. if (!shaderOutputFileLoadResult)
  323. {
  324. AZ_Error(DX12ShaderPlatformName, false, "%s", shaderOutputFileLoadResult.GetError().c_str());
  325. return false;
  326. }
  327. compiledShader = shaderOutputFileLoadResult.TakeValue();
  328. // Count the dynamic branches by searching dxc.exe's generated header file.
  329. // There might be a more ideal way to count the number of dynamic branches, perhaps using DXC libs, but doing it this way is quick and easy to set up.
  330. auto objectCodeLoadResult = AZ::RHI::LoadFileString(objectCodeOutputFile.c_str());
  331. if (objectCodeLoadResult)
  332. {
  333. // The regex here is based on dxc source code, which lists terminating instructions as:
  334. // case Ret: return "ret";
  335. // case Br: return "br";
  336. // case Switch: return "switch";
  337. // case IndirectBr: return "indirectbr";
  338. // case Invoke: return "invoke";
  339. // case Resume: return "resume";
  340. // case Unreachable: return "unreachable";
  341. // If you have to update this regex, also update UtilsTests RegexCount_DXIL
  342. byProducts.m_dynamicBranchCount = aznumeric_cast<uint32_t>(AZ::RHI::RegexCount(objectCodeLoadResult.GetValue(), "^ *(br|indirectbr|switch) "));
  343. }
  344. else
  345. {
  346. byProducts.m_dynamicBranchCount = ByProducts::UnknownDynamicBranchCount;
  347. }
  348. if (graphicsDevMode || shaderBuildArguments.m_generateDebugInfo)
  349. {
  350. byProducts.m_intermediatePaths.emplace(AZStd::move(objectCodeOutputFile));
  351. }
  352. return true;
  353. }
  354. }
  355. }