Shader.cpp 22 KB


  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RPI.Public/Shader/Shader.h>
  9. #include <Atom/RHI/Factory.h>
  10. #include <Atom/RHI/PipelineStateCache.h>
  11. #include <Atom/RHI/RHISystemInterface.h>
  12. #include <AtomCore/Instance/InstanceDatabase.h>
  13. #include <Atom/RPI.Public/Shader/ShaderReloadDebugTracker.h>
  14. #include <Atom/RPI.Public/Shader/ShaderSystemInterface.h>
  15. #include <Atom/RPI.Public/Shader/ShaderResourceGroup.h>
  16. #include <AzCore/Interface/Interface.h>
  17. #include <AzCore/std/time.h>
  18. #include <AzCore/Component/TickBus.h>
  // Version tag embedded in the pipeline state cache directory name ("..._Ver_<N>").
  // Bump this to make everyone start from a fresh PSO cache. A typed constant replaces the old macro.
  static constexpr int PSOCacheVersion = 0;
  20. namespace AZ
  21. {
  22. namespace RPI
  23. {
  24. Data::Instance<Shader> Shader::FindOrCreate(const Data::Asset<ShaderAsset>& shaderAsset, const Name& supervariantName)
  25. {
  26. auto anySupervariantName = AZStd::any(supervariantName);
  27. // retrieve the supervariant index from the shader asset
  28. SupervariantIndex supervariantIndex = shaderAsset->GetSupervariantIndex(supervariantName);
  29. if (!supervariantIndex.IsValid())
  30. {
  31. AZ_Error("Shader", false, "Supervariant with name %s, was not found in shader %s", supervariantName.GetCStr(), shaderAsset->GetName().GetCStr());
  32. return nullptr;
  33. }
  34. // Create the instance ID using the shader asset with an additional unique identifier from the Super variant index.
  35. const Data::InstanceId instanceId =
  36. Data::InstanceId::CreateFromAsset(shaderAsset, { supervariantIndex.GetIndex() });
  37. // retrieve the shader instance from the Instance database
  38. return Data::InstanceDatabase<Shader>::Instance().FindOrCreate(instanceId, shaderAsset, &anySupervariantName);
  39. }
  40. Data::Instance<Shader> Shader::FindOrCreate(const Data::Asset<ShaderAsset>& shaderAsset)
  41. {
  42. return FindOrCreate(shaderAsset, AZ::Name{ "" });
  43. }
  44. Data::Instance<Shader> Shader::CreateInternal([[maybe_unused]] ShaderAsset& shaderAsset, const AZStd::any* anySupervariantName)
  45. {
  46. AZ_Assert(anySupervariantName != nullptr, "Invalid supervariant name param");
  47. auto supervariantName = AZStd::any_cast<AZ::Name>(*anySupervariantName);
  48. auto supervariantIndex = shaderAsset.GetSupervariantIndex(supervariantName);
  49. if (!supervariantIndex.IsValid())
  50. {
  51. AZ_Error("Shader", false, "Supervariant with name %s, was not found in shader %s", supervariantName.GetCStr(), shaderAsset.GetName().GetCStr());
  52. return nullptr;
  53. }
  54. Data::Instance<Shader> shader = aznew Shader(supervariantIndex);
  55. const RHI::ResultCode resultCode = shader->Init(shaderAsset);
  56. if (resultCode != RHI::ResultCode::Success)
  57. {
  58. return nullptr;
  59. }
  60. return shader;
  61. }
  62. Shader::~Shader()
  63. {
  64. Shutdown();
  65. }
  66. static bool GetPipelineLibraryPath(char* pipelineLibraryPath, size_t pipelineLibraryPathLength, const ShaderAsset& shaderAsset)
  67. {
  68. if (auto* fileIOBase = IO::FileIOBase::GetInstance())
  69. {
  70. const Data::AssetId& assetId = shaderAsset.GetId();
  71. Name platformName = RHI::Factory::Get().GetName();
  72. Name shaderName = shaderAsset.GetName();
  73. AZStd::string uuidString;
  74. assetId.m_guid.ToString<AZStd::string>(uuidString, false, false);
  75. RHI::RHISystemInterface* rhiSystem = RHI::RHISystemInterface::Get();
  76. RHI::PhysicalDeviceDescriptor physicalDeviceDesc = rhiSystem->GetDevice()->GetPhysicalDevice().GetDescriptor();
  77. AZStd::string configString;
  78. if (RHI::BuildOptions::IsDebugBuild)
  79. {
  80. configString = "Debug";
  81. }
  82. else if (RHI::BuildOptions::IsProfileBuild)
  83. {
  84. configString = "Profile";
  85. }
  86. else
  87. {
  88. configString = "Release";
  89. }
  90. char pipelineLibraryPathTemp[AZ_MAX_PATH_LEN];
  91. azsnprintf(
  92. pipelineLibraryPathTemp, AZ_MAX_PATH_LEN, "@user@/Atom/PipelineStateCache_%s_%u_%u_%s_Ver_%i/%s/%s_%s_%d.bin",
  93. ToString(physicalDeviceDesc.m_vendorId).data(), physicalDeviceDesc.m_deviceId,
  94. physicalDeviceDesc.m_driverVersion, configString.data(),
  95. PSOCacheVersion, platformName.GetCStr(),
  96. shaderName.GetCStr(), uuidString.data(),
  97. assetId.m_subId);
  98. fileIOBase->ResolvePath(pipelineLibraryPathTemp, pipelineLibraryPath, pipelineLibraryPathLength);
  99. return true;
  100. }
  101. return false;
  102. }
  103. RHI::ResultCode Shader::Init(ShaderAsset& shaderAsset)
  104. {
  105. Data::AssetBus::Handler::BusDisconnect();
  106. ShaderVariantFinderNotificationBus::Handler::BusDisconnect();
  107. RHI::RHISystemInterface* rhiSystem = RHI::RHISystemInterface::Get();
  108. RHI::DrawListTagRegistry* drawListTagRegistry = rhiSystem->GetDrawListTagRegistry();
  109. m_asset = { &shaderAsset, AZ::Data::AssetLoadBehavior::PreLoad };
  110. m_pipelineStateType = shaderAsset.GetPipelineStateType();
  111. GetPipelineLibraryPath(m_pipelineLibraryPath, AZ_MAX_PATH_LEN, *m_asset);
  112. {
  113. AZStd::unique_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  114. m_shaderVariants.clear();
  115. }
  116. auto rootShaderVariantAsset = shaderAsset.GetRootVariantAsset(m_supervariantIndex);
  117. m_rootVariant.Init(m_asset, rootShaderVariantAsset, m_supervariantIndex);
  118. if (m_pipelineLibraryHandle.IsNull())
  119. {
  120. // We set up a pipeline library only once for the lifetime of the Shader instance.
  121. // This should allow the Shader to be reloaded at runtime many times, and cache and reuse PipelineState objects rather than rebuild them.
  122. // It also fixes a particular TDR crash that occurred on some hardware when hot-reloading shaders and building pipeline states
  123. // in a new pipeline library every time.
  124. RHI::PipelineStateCache* pipelineStateCache = rhiSystem->GetPipelineStateCache();
  125. ConstPtr<RHI::PipelineLibraryData> serializedData = LoadPipelineLibrary();
  126. RHI::PipelineLibraryHandle pipelineLibraryHandle = pipelineStateCache->CreateLibrary(serializedData.get(), m_pipelineLibraryPath);
  127. if (pipelineLibraryHandle.IsNull())
  128. {
  129. AZ_Error("Shader", false, "Failed to create pipeline library from pipeline state cache.");
  130. return RHI::ResultCode::Fail;
  131. }
  132. m_pipelineLibraryHandle = pipelineLibraryHandle;
  133. m_pipelineStateCache = pipelineStateCache;
  134. }
  135. const Name& drawListName = shaderAsset.GetDrawListName();
  136. if (!drawListName.IsEmpty())
  137. {
  138. m_drawListTag = drawListTagRegistry->AcquireTag(drawListName);
  139. if (!m_drawListTag.IsValid())
  140. {
  141. AZ_Error("Shader", false, "Failed to acquire a DrawListTag. Entries are full.");
  142. }
  143. }
  144. ShaderVariantFinderNotificationBus::Handler::BusConnect(m_asset.GetId());
  145. Data::AssetBus::Handler::BusConnect(m_asset.GetId());
  146. return RHI::ResultCode::Success;
  147. }
  148. void Shader::Shutdown()
  149. {
  150. ShaderVariantFinderNotificationBus::Handler::BusDisconnect();
  151. Data::AssetBus::Handler::BusDisconnect();
  152. if (m_pipelineLibraryHandle.IsValid())
  153. {
  154. if (r_enablePsoCaching)
  155. {
  156. SavePipelineLibrary();
  157. }
  158. m_pipelineStateCache->ReleaseLibrary(m_pipelineLibraryHandle);
  159. m_pipelineStateCache = nullptr;
  160. m_pipelineLibraryHandle = {};
  161. }
  162. if (m_drawListTag.IsValid())
  163. {
  164. RHI::DrawListTagRegistry* drawListTagRegistry = RHI::RHISystemInterface::Get()->GetDrawListTagRegistry();
  165. drawListTagRegistry->ReleaseTag(m_drawListTag);
  166. m_drawListTag.Reset();
  167. }
  168. }
  169. ///////////////////////////////////////////////////////////////////////
  170. // AssetBus overrides
  171. void Shader::OnAssetReloaded(Data::Asset<Data::AssetData> asset)
  172. {
  173. ShaderReloadDebugTracker::ScopedSection reloadSection("{%p}->Shader::OnAssetReloaded %s", this, asset.GetHint().c_str());
  174. m_asset = asset;
  175. if (ShaderReloadDebugTracker::IsEnabled())
  176. {
  177. AZStd::sys_time_t now = AZStd::GetTimeUTCMilliSecond();
  178. const auto shaderVariantAsset = m_asset->GetRootVariantAsset();
  179. ShaderReloadDebugTracker::Printf("{%p}->Shader::OnAssetReloaded for shader '%s' [current time %lld] found variant '%s'",
  180. this, m_asset.GetHint().c_str(), now, shaderVariantAsset.GetHint().c_str());
  181. }
  182. Init(*m_asset.Get());
  183. ShaderReloadNotificationBus::Event(asset.GetId(), &ShaderReloadNotificationBus::Events::OnShaderReinitialized, *this);
  184. }
  185. ///////////////////////////////////////////////////////////////////////
  186. ///////////////////////////////////////////////////////////////////
  187. /// ShaderVariantFinderNotificationBus overrides
  188. void Shader::OnShaderVariantAssetReady(Data::Asset<ShaderVariantAsset> shaderVariantAsset, bool isError)
  189. {
  190. ShaderReloadDebugTracker::ScopedSection reloadSection("{%p}->Shader::OnShaderVariantAssetReady %s", this, shaderVariantAsset.GetHint().c_str());
  191. AZ_Assert(shaderVariantAsset, "Reloaded ShaderVariantAsset is null");
  192. const ShaderVariantStableId stableId = shaderVariantAsset->GetStableId();
  193. // check the supervariantIndex of the ShaderVariantAsset to make sure it matches the supervariantIndex of this shader instance
  194. if (shaderVariantAsset->GetSupervariantIndex() != m_supervariantIndex.GetIndex())
  195. {
  196. return;
  197. }
  198. // We make a copy of the updated variant because OnShaderVariantReinitialized must not be called inside
  199. // m_variantCacheMutex or deadlocks may occur.
  200. // Or if there is an error, we leave this object in its default state to indicate there was an error.
  201. // [GFX TODO] We really should have a dedicated message/event for this, but that will be covered by a future task where
  202. // we will merge ShaderReloadNotificationBus messages into one. For now, we just indicate the error by passing an empty ShaderVariant,
  203. // all our call sites don't use this data anyway.
  204. ShaderVariant updatedVariant;
  205. if (isError)
  206. {
  207. //Remark: We do not assert if the stableId == RootShaderVariantStableId, because we can not trust in the asset data
  208. //on error. so it is possible that on error the stbleId == RootShaderVariantStableId;
  209. if (stableId == RootShaderVariantStableId)
  210. {
  211. return;
  212. }
  213. AZStd::unique_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  214. m_shaderVariants.erase(stableId);
  215. }
  216. else
  217. {
  218. AZ_Assert(stableId != RootShaderVariantStableId,
  219. "The root variant is expected to be updated by the ShaderAsset.");
  220. AZStd::unique_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  221. auto iter = m_shaderVariants.find(stableId);
  222. if (iter != m_shaderVariants.end())
  223. {
  224. ShaderVariant& shaderVariant = iter->second;
  225. if (!shaderVariant.Init(m_asset, shaderVariantAsset, m_supervariantIndex))
  226. {
  227. AZ_Error("Shader", false, "Failed to init shaderVariant with StableId=%u", shaderVariantAsset->GetStableId());
  228. m_shaderVariants.erase(stableId);
  229. }
  230. else
  231. {
  232. updatedVariant = shaderVariant;
  233. }
  234. }
  235. else
  236. {
  237. //This is the first time the shader variant asset comes to life.
  238. updatedVariant.Init(m_asset, shaderVariantAsset, m_supervariantIndex);
  239. m_shaderVariants.emplace(stableId, updatedVariant);
  240. }
  241. }
  242. // [GFX TODO] It might make more sense to call OnShaderReinitialized here
  243. ShaderReloadNotificationBus::Event(m_asset.GetId(), &ShaderReloadNotificationBus::Events::OnShaderVariantReinitialized, updatedVariant);
  244. }
  245. ///////////////////////////////////////////////////////////////////
  246. ConstPtr<RHI::PipelineLibraryData> Shader::LoadPipelineLibrary() const
  247. {
  248. RHI::Device* device = RHI::RHISystemInterface::Get()->GetDevice();
  249. //Check if explicit file load/save operation is needed as the RHI backend api may not support it
  250. if (m_pipelineLibraryPath[0] != 0 && device->GetFeatures().m_isPsoCacheFileOperationsNeeded)
  251. {
  252. return Utils::LoadObjectFromFile<RHI::PipelineLibraryData>(m_pipelineLibraryPath);
  253. }
  254. return nullptr;
  255. }
  256. void Shader::SavePipelineLibrary() const
  257. {
  258. RHI::Device* device = RHI::RHISystemInterface::Get()->GetDevice();
  259. if (m_pipelineLibraryPath[0] != 0)
  260. {
  261. RHI::ConstPtr<RHI::PipelineLibrary> pipelineLib = m_pipelineStateCache->GetMergedLibrary(m_pipelineLibraryHandle);
  262. if(!pipelineLib)
  263. {
  264. return;
  265. }
  266. //Check if explicit file load/save operation is needed as the RHI backend api may not support it
  267. if (device->GetFeatures().m_isPsoCacheFileOperationsNeeded)
  268. {
  269. RHI::ConstPtr<RHI::PipelineLibraryData> serializedData = pipelineLib->GetSerializedData();
  270. if(serializedData)
  271. {
  272. Utils::SaveObjectToFile<RHI::PipelineLibraryData>(m_pipelineLibraryPath, DataStream::ST_BINARY, serializedData.get());
  273. }
  274. }
  275. else
  276. {
  277. [[maybe_unused]] bool result = pipelineLib->SaveSerializedData(m_pipelineLibraryPath);
  278. AZ_Error("Shader", result, "Pipeline Library %s was not saved", &m_pipelineLibraryPath);
  279. }
  280. }
  281. }
  282. ShaderOptionGroup Shader::CreateShaderOptionGroup() const
  283. {
  284. return ShaderOptionGroup(m_asset->GetShaderOptionGroupLayout());
  285. }
  286. const ShaderVariant& Shader::GetVariant(const ShaderVariantId& shaderVariantId)
  287. {
  288. Data::Asset<ShaderVariantAsset> shaderVariantAsset = m_asset->GetVariantAsset(shaderVariantId, m_supervariantIndex);
  289. if (!shaderVariantAsset || shaderVariantAsset->IsRootVariant())
  290. {
  291. return m_rootVariant;
  292. }
  293. return GetVariant(shaderVariantAsset->GetStableId());
  294. }
  295. const ShaderVariant& Shader::GetRootVariant()
  296. {
  297. return m_rootVariant;
  298. }
  299. const ShaderVariant& Shader::GetDefaultVariant()
  300. {
  301. ShaderOptionGroup defaultOptions = GetDefaultShaderOptions();
  302. return GetVariant(defaultOptions.GetShaderVariantId());
  303. }
  304. ShaderOptionGroup Shader::GetDefaultShaderOptions() const
  305. {
  306. return m_asset->GetDefaultShaderOptions();
  307. }
  308. ShaderVariantSearchResult Shader::FindVariantStableId(const ShaderVariantId& shaderVariantId) const
  309. {
  310. ShaderVariantSearchResult variantSearchResult = m_asset->FindVariantStableId(shaderVariantId);
  311. return variantSearchResult;
  312. }
  313. const ShaderVariant& Shader::GetVariant(ShaderVariantStableId shaderVariantStableId)
  314. {
  315. const ShaderVariant& variant = GetVariantInternal(shaderVariantStableId);
  316. if (ShaderReloadDebugTracker::IsEnabled())
  317. {
  318. AZStd::sys_time_t now = AZStd::GetTimeUTCMilliSecond();
  319. ShaderReloadDebugTracker::Printf("{%p}->Shader::GetVariant for shader '%s' [current time %lld] found variant '%s'",
  320. this, m_asset.GetHint().c_str(), now, variant.GetShaderVariantAsset().GetHint().c_str());
  321. }
  322. return variant;
  323. }
  324. const ShaderVariant& Shader::GetVariantInternal(ShaderVariantStableId shaderVariantStableId)
  325. {
  326. if (!shaderVariantStableId.IsValid() || shaderVariantStableId == ShaderAsset::RootShaderVariantStableId)
  327. {
  328. return m_rootVariant;
  329. }
  330. {
  331. AZStd::shared_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  332. auto findIt = m_shaderVariants.find(shaderVariantStableId);
  333. if (findIt != m_shaderVariants.end())
  334. {
  335. return findIt->second;
  336. }
  337. }
  338. // By calling GetVariant, an asynchronous asset load request is enqueued if the variant
  339. // is not fully ready.
  340. Data::Asset<ShaderVariantAsset> shaderVariantAsset = m_asset->GetVariantAsset(shaderVariantStableId, m_supervariantIndex);
  341. if (!shaderVariantAsset || shaderVariantAsset == m_asset->GetRootVariantAsset())
  342. {
  343. // Return the root variant when the requested variant is not ready.
  344. return m_rootVariant;
  345. }
  346. AZStd::unique_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  347. // For performance reasons We are breaking this function into two locking steps.
  348. // which means We must check again if the variant is already in the cache.
  349. auto findIt = m_shaderVariants.find(shaderVariantStableId);
  350. if (findIt != m_shaderVariants.end())
  351. {
  352. return findIt->second;
  353. }
  354. ShaderVariant newVariant;
  355. newVariant.Init(m_asset, shaderVariantAsset, m_supervariantIndex);
  356. m_shaderVariants.emplace(shaderVariantStableId, newVariant);
  357. return m_shaderVariants.at(shaderVariantStableId);
  358. }
  359. RHI::PipelineStateType Shader::GetPipelineStateType() const
  360. {
  361. return m_pipelineStateType;
  362. }
  363. const ShaderInputContract& Shader::GetInputContract() const
  364. {
  365. return m_asset->GetInputContract(m_supervariantIndex);
  366. }
  367. const ShaderOutputContract& Shader::GetOutputContract() const
  368. {
  369. return m_asset->GetOutputContract(m_supervariantIndex);
  370. }
  371. const RHI::PipelineState* Shader::AcquirePipelineState(const RHI::PipelineStateDescriptor& descriptor) const
  372. {
  373. return m_pipelineStateCache->AcquirePipelineState(m_pipelineLibraryHandle, descriptor, m_asset->GetName());
  374. }
  375. const RHI::Ptr<RHI::ShaderResourceGroupLayout>& Shader::FindShaderResourceGroupLayout(const Name& shaderResourceGroupName) const
  376. {
  377. return m_asset->FindShaderResourceGroupLayout(shaderResourceGroupName, m_supervariantIndex);
  378. }
  379. const RHI::Ptr<RHI::ShaderResourceGroupLayout>& Shader::FindShaderResourceGroupLayout(uint32_t bindingSlot) const
  380. {
  381. return m_asset->FindShaderResourceGroupLayout(bindingSlot, m_supervariantIndex);
  382. }
  383. const RHI::Ptr<RHI::ShaderResourceGroupLayout>& Shader::FindFallbackShaderResourceGroupLayout() const
  384. {
  385. return m_asset->FindFallbackShaderResourceGroupLayout(m_supervariantIndex);
  386. }
  387. AZStd::span<const RHI::Ptr<RHI::ShaderResourceGroupLayout>> Shader::GetShaderResourceGroupLayouts() const
  388. {
  389. return m_asset->GetShaderResourceGroupLayouts(m_supervariantIndex);
  390. }
  391. Data::Instance<ShaderResourceGroup> Shader::CreateDrawSrgForShaderVariant(const ShaderOptionGroup& shaderOptions, bool compileTheSrg)
  392. {
  393. RHI::Ptr<RHI::ShaderResourceGroupLayout> drawSrgLayout = m_asset->GetDrawSrgLayout(GetSupervariantIndex());
  394. Data::Instance<ShaderResourceGroup> drawSrg;
  395. if (drawSrgLayout)
  396. {
  397. drawSrg = RPI::ShaderResourceGroup::Create(m_asset, GetSupervariantIndex(), drawSrgLayout->GetName());
  398. if (drawSrgLayout->HasShaderVariantKeyFallbackEntry())
  399. {
  400. drawSrg->SetShaderVariantKeyFallbackValue(shaderOptions.GetShaderVariantKeyFallbackValue());
  401. }
  402. if (compileTheSrg)
  403. {
  404. drawSrg->Compile();
  405. }
  406. }
  407. return drawSrg;
  408. }
  409. Data::Instance<ShaderResourceGroup> Shader::CreateDefaultDrawSrg(bool compileTheSrg)
  410. {
  411. return CreateDrawSrgForShaderVariant(m_asset->GetDefaultShaderOptions(), compileTheSrg);
  412. }
  413. const Data::Asset<ShaderAsset>& Shader::GetAsset() const
  414. {
  415. return m_asset;
  416. }
  417. RHI::DrawListTag Shader::GetDrawListTag() const
  418. {
  419. return m_drawListTag;
  420. }
  421. } // namespace RPI
  422. } // namespace AZ