Shader.cpp 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RPI.Public/Shader/Shader.h>
  9. #include <Atom/RHI/Factory.h>
  10. #include <Atom/RHI/PipelineStateCache.h>
  11. #include <Atom/RHI/RHISystemInterface.h>
  12. #include <AtomCore/Instance/InstanceDatabase.h>
  13. #include <Atom/RPI.Public/Shader/ShaderReloadDebugTracker.h>
  14. #include <Atom/RPI.Public/Shader/ShaderSystemInterface.h>
  15. #include <Atom/RPI.Public/Shader/ShaderResourceGroup.h>
  16. #include <AzCore/Interface/Interface.h>
  17. #include <AzCore/std/time.h>
  18. #include <AzCore/Component/TickBus.h>
  19. #define PSOCacheVersion 0 // Bump this if you want to reset PSO cache for everyone
  20. namespace AZ
  21. {
  22. namespace RPI
  23. {
  24. Data::Instance<Shader> Shader::FindOrCreate(const Data::Asset<ShaderAsset>& shaderAsset, const Name& supervariantName)
  25. {
  26. auto anySupervariantName = AZStd::any(supervariantName);
  27. // retrieve the supervariant index from the shader asset
  28. SupervariantIndex supervariantIndex = shaderAsset->GetSupervariantIndex(supervariantName);
  29. if (!supervariantIndex.IsValid())
  30. {
  31. AZ_Error("Shader", false, "Supervariant with name %s, was not found in shader %s", supervariantName.GetCStr(), shaderAsset->GetName().GetCStr());
  32. return nullptr;
  33. }
  34. // Create the instance ID using the shader asset with an additional unique identifier from the Super variant index.
  35. const Data::InstanceId instanceId =
  36. Data::InstanceId::CreateFromAsset(shaderAsset, { supervariantIndex.GetIndex() });
  37. // retrieve the shader instance from the Instance database
  38. return Data::InstanceDatabase<Shader>::Instance().FindOrCreate(instanceId, shaderAsset, &anySupervariantName);
  39. }
  40. Data::Instance<Shader> Shader::FindOrCreate(const Data::Asset<ShaderAsset>& shaderAsset)
  41. {
  42. return FindOrCreate(shaderAsset, AZ::Name{ "" });
  43. }
  44. Data::Instance<Shader> Shader::CreateInternal([[maybe_unused]] ShaderAsset& shaderAsset, const AZStd::any* anySupervariantName)
  45. {
  46. AZ_Assert(anySupervariantName != nullptr, "Invalid supervariant name param");
  47. auto supervariantName = AZStd::any_cast<AZ::Name>(*anySupervariantName);
  48. auto supervariantIndex = shaderAsset.GetSupervariantIndex(supervariantName);
  49. if (!supervariantIndex.IsValid())
  50. {
  51. AZ_Error("Shader", false, "Supervariant with name %s, was not found in shader %s", supervariantName.GetCStr(), shaderAsset.GetName().GetCStr());
  52. return nullptr;
  53. }
  54. Data::Instance<Shader> shader = aznew Shader(supervariantIndex);
  55. const RHI::ResultCode resultCode = shader->Init(shaderAsset);
  56. if (resultCode != RHI::ResultCode::Success)
  57. {
  58. return nullptr;
  59. }
  60. return shader;
  61. }
  62. Shader::~Shader()
  63. {
  64. Shutdown();
  65. }
  66. static bool GetPipelineLibraryPath(char* pipelineLibraryPath, size_t pipelineLibraryPathLength, const ShaderAsset& shaderAsset)
  67. {
  68. if (auto* fileIOBase = IO::FileIOBase::GetInstance())
  69. {
  70. const Data::AssetId& assetId = shaderAsset.GetId();
  71. Name platformName = RHI::Factory::Get().GetName();
  72. Name shaderName = shaderAsset.GetName();
  73. AZStd::string uuidString;
  74. assetId.m_guid.ToString<AZStd::string>(uuidString, false, false);
  75. RHI::RHISystemInterface* rhiSystem = RHI::RHISystemInterface::Get();
  76. RHI::PhysicalDeviceDescriptor physicalDeviceDesc = rhiSystem->GetDevice()->GetPhysicalDevice().GetDescriptor();
  77. AZStd::string configString;
  78. if (RHI::BuildOptions::IsDebugBuild)
  79. {
  80. configString = "Debug";
  81. }
  82. else if (RHI::BuildOptions::IsProfileBuild)
  83. {
  84. configString = "Profile";
  85. }
  86. else
  87. {
  88. configString = "Release";
  89. }
  90. char pipelineLibraryPathTemp[AZ_MAX_PATH_LEN];
  91. azsnprintf(
  92. pipelineLibraryPathTemp, AZ_MAX_PATH_LEN, "@user@/Atom/PipelineStateCache_%s_%u_%u_%s_Ver_%i/%s/%s_%s_%d.bin",
  93. ToString(physicalDeviceDesc.m_vendorId).data(), physicalDeviceDesc.m_deviceId,
  94. physicalDeviceDesc.m_driverVersion, configString.data(),
  95. PSOCacheVersion, platformName.GetCStr(),
  96. shaderName.GetCStr(), uuidString.data(),
  97. assetId.m_subId);
  98. fileIOBase->ResolvePath(pipelineLibraryPathTemp, pipelineLibraryPath, pipelineLibraryPathLength);
  99. return true;
  100. }
  101. return false;
  102. }
  103. RHI::ResultCode Shader::Init(ShaderAsset& shaderAsset)
  104. {
  105. Data::AssetBus::Handler::BusDisconnect();
  106. ShaderVariantFinderNotificationBus::Handler::BusDisconnect();
  107. RHI::RHISystemInterface* rhiSystem = RHI::RHISystemInterface::Get();
  108. RHI::DrawListTagRegistry* drawListTagRegistry = rhiSystem->GetDrawListTagRegistry();
  109. m_asset = { &shaderAsset, AZ::Data::AssetLoadBehavior::PreLoad };
  110. m_pipelineStateType = shaderAsset.GetPipelineStateType();
  111. GetPipelineLibraryPath(m_pipelineLibraryPath, AZ_MAX_PATH_LEN, *m_asset);
  112. {
  113. AZStd::unique_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  114. m_shaderVariants.clear();
  115. }
  116. auto rootShaderVariantAsset = shaderAsset.GetRootVariantAsset(m_supervariantIndex);
  117. m_rootVariant.Init(m_asset, rootShaderVariantAsset, m_supervariantIndex);
  118. if (m_pipelineLibraryHandle.IsNull())
  119. {
  120. // We set up a pipeline library only once for the lifetime of the Shader instance.
  121. // This should allow the Shader to be reloaded at runtime many times, and cache and reuse PipelineState objects rather than rebuild them.
  122. // It also fixes a particular TDR crash that occurred on some hardware when hot-reloading shaders and building pipeline states
  123. // in a new pipeline library every time.
  124. RHI::PipelineStateCache* pipelineStateCache = rhiSystem->GetPipelineStateCache();
  125. ConstPtr<RHI::PipelineLibraryData> serializedData = LoadPipelineLibrary();
  126. RHI::PipelineLibraryHandle pipelineLibraryHandle = pipelineStateCache->CreateLibrary(serializedData.get(), m_pipelineLibraryPath);
  127. if (pipelineLibraryHandle.IsNull())
  128. {
  129. AZ_Error("Shader", false, "Failed to create pipeline library from pipeline state cache.");
  130. return RHI::ResultCode::Fail;
  131. }
  132. m_pipelineLibraryHandle = pipelineLibraryHandle;
  133. m_pipelineStateCache = pipelineStateCache;
  134. }
  135. const Name& drawListName = shaderAsset.GetDrawListName();
  136. if (!drawListName.IsEmpty())
  137. {
  138. m_drawListTag = drawListTagRegistry->AcquireTag(drawListName);
  139. if (!m_drawListTag.IsValid())
  140. {
  141. AZ_Error("Shader", false, "Failed to acquire a DrawListTag. Entries are full.");
  142. }
  143. }
  144. ShaderVariantFinderNotificationBus::Handler::BusConnect(m_asset.GetId());
  145. Data::AssetBus::Handler::BusConnect(m_asset.GetId());
  146. return RHI::ResultCode::Success;
  147. }
  148. void Shader::Shutdown()
  149. {
  150. ShaderVariantFinderNotificationBus::Handler::BusDisconnect();
  151. Data::AssetBus::Handler::BusDisconnect();
  152. if (m_pipelineLibraryHandle.IsValid())
  153. {
  154. SavePipelineLibrary();
  155. m_pipelineStateCache->ReleaseLibrary(m_pipelineLibraryHandle);
  156. m_pipelineStateCache = nullptr;
  157. m_pipelineLibraryHandle = {};
  158. }
  159. if (m_drawListTag.IsValid())
  160. {
  161. RHI::DrawListTagRegistry* drawListTagRegistry = RHI::RHISystemInterface::Get()->GetDrawListTagRegistry();
  162. drawListTagRegistry->ReleaseTag(m_drawListTag);
  163. m_drawListTag.Reset();
  164. }
  165. }
  166. ///////////////////////////////////////////////////////////////////////
  167. // AssetBus overrides
  168. void Shader::OnAssetReloaded(Data::Asset<Data::AssetData> asset)
  169. {
  170. ShaderReloadDebugTracker::ScopedSection reloadSection("{%p}->Shader::OnAssetReloaded %s", this, asset.GetHint().c_str());
  171. m_asset = asset;
  172. if (ShaderReloadDebugTracker::IsEnabled())
  173. {
  174. AZStd::sys_time_t now = AZStd::GetTimeUTCMilliSecond();
  175. const auto shaderVariantAsset = m_asset->GetRootVariantAsset();
  176. ShaderReloadDebugTracker::Printf("{%p}->Shader::OnAssetReloaded for shader '%s' [current time %lld] found variant '%s'",
  177. this, m_asset.GetHint().c_str(), now, shaderVariantAsset.GetHint().c_str());
  178. }
  179. Init(*m_asset.Get());
  180. ShaderReloadNotificationBus::Event(asset.GetId(), &ShaderReloadNotificationBus::Events::OnShaderReinitialized, *this);
  181. }
  182. ///////////////////////////////////////////////////////////////////////
  183. ///////////////////////////////////////////////////////////////////
  184. /// ShaderVariantFinderNotificationBus overrides
  185. void Shader::OnShaderVariantAssetReady(Data::Asset<ShaderVariantAsset> shaderVariantAsset, bool isError)
  186. {
  187. ShaderReloadDebugTracker::ScopedSection reloadSection("{%p}->Shader::OnShaderVariantAssetReady %s", this, shaderVariantAsset.GetHint().c_str());
  188. AZ_Assert(shaderVariantAsset, "Reloaded ShaderVariantAsset is null");
  189. const ShaderVariantStableId stableId = shaderVariantAsset->GetStableId();
  190. // check the supervariantIndex of the ShaderVariantAsset to make sure it matches the supervariantIndex of this shader instance
  191. if (shaderVariantAsset->GetSupervariantIndex() != m_supervariantIndex.GetIndex())
  192. {
  193. return;
  194. }
  195. // We make a copy of the updated variant because OnShaderVariantReinitialized must not be called inside
  196. // m_variantCacheMutex or deadlocks may occur.
  197. // Or if there is an error, we leave this object in its default state to indicate there was an error.
  198. // [GFX TODO] We really should have a dedicated message/event for this, but that will be covered by a future task where
  199. // we will merge ShaderReloadNotificationBus messages into one. For now, we just indicate the error by passing an empty ShaderVariant,
  200. // all our call sites don't use this data anyway.
  201. ShaderVariant updatedVariant;
  202. if (isError)
  203. {
  204. //Remark: We do not assert if the stableId == RootShaderVariantStableId, because we can not trust in the asset data
  205. //on error. so it is possible that on error the stbleId == RootShaderVariantStableId;
  206. if (stableId == RootShaderVariantStableId)
  207. {
  208. return;
  209. }
  210. AZStd::unique_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  211. m_shaderVariants.erase(stableId);
  212. }
  213. else
  214. {
  215. AZ_Assert(stableId != RootShaderVariantStableId,
  216. "The root variant is expected to be updated by the ShaderAsset.");
  217. AZStd::unique_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  218. auto iter = m_shaderVariants.find(stableId);
  219. if (iter != m_shaderVariants.end())
  220. {
  221. ShaderVariant& shaderVariant = iter->second;
  222. if (!shaderVariant.Init(m_asset, shaderVariantAsset, m_supervariantIndex))
  223. {
  224. AZ_Error("Shader", false, "Failed to init shaderVariant with StableId=%u", shaderVariantAsset->GetStableId());
  225. m_shaderVariants.erase(stableId);
  226. }
  227. else
  228. {
  229. updatedVariant = shaderVariant;
  230. }
  231. }
  232. else
  233. {
  234. //This is the first time the shader variant asset comes to life.
  235. updatedVariant.Init(m_asset, shaderVariantAsset, m_supervariantIndex);
  236. m_shaderVariants.emplace(stableId, updatedVariant);
  237. }
  238. }
  239. // [GFX TODO] It might make more sense to call OnShaderReinitialized here
  240. ShaderReloadNotificationBus::Event(m_asset.GetId(), &ShaderReloadNotificationBus::Events::OnShaderVariantReinitialized, updatedVariant);
  241. }
  242. ///////////////////////////////////////////////////////////////////
  243. ConstPtr<RHI::PipelineLibraryData> Shader::LoadPipelineLibrary() const
  244. {
  245. RHI::Device* device = RHI::RHISystemInterface::Get()->GetDevice();
  246. //Check if explicit file load/save operation is needed as the RHI backend api may not support it
  247. if (m_pipelineLibraryPath[0] != 0 && device->GetFeatures().m_isPsoCacheFileOperationsNeeded)
  248. {
  249. return Utils::LoadObjectFromFile<RHI::PipelineLibraryData>(m_pipelineLibraryPath);
  250. }
  251. return nullptr;
  252. }
  253. void Shader::SavePipelineLibrary() const
  254. {
  255. RHI::Device* device = RHI::RHISystemInterface::Get()->GetDevice();
  256. if (m_pipelineLibraryPath[0] != 0)
  257. {
  258. RHI::ConstPtr<RHI::PipelineLibrary> pipelineLib = m_pipelineStateCache->GetMergedLibrary(m_pipelineLibraryHandle);
  259. if(!pipelineLib)
  260. {
  261. return;
  262. }
  263. //Check if explicit file load/save operation is needed as the RHI backend api may not support it
  264. if (device->GetFeatures().m_isPsoCacheFileOperationsNeeded)
  265. {
  266. RHI::ConstPtr<RHI::PipelineLibraryData> serializedData = pipelineLib->GetSerializedData();
  267. if(serializedData)
  268. {
  269. Utils::SaveObjectToFile<RHI::PipelineLibraryData>(m_pipelineLibraryPath, DataStream::ST_BINARY, serializedData.get());
  270. }
  271. }
  272. else
  273. {
  274. [[maybe_unused]] bool result = pipelineLib->SaveSerializedData(m_pipelineLibraryPath);
  275. AZ_Error("Shader", result, "Pipeline Library %s was not saved", &m_pipelineLibraryPath);
  276. }
  277. }
  278. }
  279. ShaderOptionGroup Shader::CreateShaderOptionGroup() const
  280. {
  281. return ShaderOptionGroup(m_asset->GetShaderOptionGroupLayout());
  282. }
  283. const ShaderVariant& Shader::GetVariant(const ShaderVariantId& shaderVariantId)
  284. {
  285. Data::Asset<ShaderVariantAsset> shaderVariantAsset = m_asset->GetVariantAsset(shaderVariantId, m_supervariantIndex);
  286. if (!shaderVariantAsset || shaderVariantAsset->IsRootVariant())
  287. {
  288. return m_rootVariant;
  289. }
  290. return GetVariant(shaderVariantAsset->GetStableId());
  291. }
  292. const ShaderVariant& Shader::GetRootVariant()
  293. {
  294. return m_rootVariant;
  295. }
  296. const ShaderVariant& Shader::GetDefaultVariant()
  297. {
  298. ShaderOptionGroup defaultOptions = GetDefaultShaderOptions();
  299. return GetVariant(defaultOptions.GetShaderVariantId());
  300. }
  301. ShaderOptionGroup Shader::GetDefaultShaderOptions() const
  302. {
  303. return m_asset->GetDefaultShaderOptions();
  304. }
  305. ShaderVariantSearchResult Shader::FindVariantStableId(const ShaderVariantId& shaderVariantId) const
  306. {
  307. ShaderVariantSearchResult variantSearchResult = m_asset->FindVariantStableId(shaderVariantId);
  308. return variantSearchResult;
  309. }
  310. const ShaderVariant& Shader::GetVariant(ShaderVariantStableId shaderVariantStableId)
  311. {
  312. const ShaderVariant& variant = GetVariantInternal(shaderVariantStableId);
  313. if (ShaderReloadDebugTracker::IsEnabled())
  314. {
  315. AZStd::sys_time_t now = AZStd::GetTimeUTCMilliSecond();
  316. ShaderReloadDebugTracker::Printf("{%p}->Shader::GetVariant for shader '%s' [current time %lld] found variant '%s'",
  317. this, m_asset.GetHint().c_str(), now, variant.GetShaderVariantAsset().GetHint().c_str());
  318. }
  319. return variant;
  320. }
  321. const ShaderVariant& Shader::GetVariantInternal(ShaderVariantStableId shaderVariantStableId)
  322. {
  323. if (!shaderVariantStableId.IsValid() || shaderVariantStableId == ShaderAsset::RootShaderVariantStableId)
  324. {
  325. return m_rootVariant;
  326. }
  327. {
  328. AZStd::shared_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  329. auto findIt = m_shaderVariants.find(shaderVariantStableId);
  330. if (findIt != m_shaderVariants.end())
  331. {
  332. return findIt->second;
  333. }
  334. }
  335. // By calling GetVariant, an asynchronous asset load request is enqueued if the variant
  336. // is not fully ready.
  337. Data::Asset<ShaderVariantAsset> shaderVariantAsset = m_asset->GetVariantAsset(shaderVariantStableId, m_supervariantIndex);
  338. if (!shaderVariantAsset || shaderVariantAsset == m_asset->GetRootVariantAsset())
  339. {
  340. // Return the root variant when the requested variant is not ready.
  341. return m_rootVariant;
  342. }
  343. AZStd::unique_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  344. // For performance reasons We are breaking this function into two locking steps.
  345. // which means We must check again if the variant is already in the cache.
  346. auto findIt = m_shaderVariants.find(shaderVariantStableId);
  347. if (findIt != m_shaderVariants.end())
  348. {
  349. return findIt->second;
  350. }
  351. ShaderVariant newVariant;
  352. newVariant.Init(m_asset, shaderVariantAsset, m_supervariantIndex);
  353. m_shaderVariants.emplace(shaderVariantStableId, newVariant);
  354. return m_shaderVariants.at(shaderVariantStableId);
  355. }
  356. RHI::PipelineStateType Shader::GetPipelineStateType() const
  357. {
  358. return m_pipelineStateType;
  359. }
  360. const ShaderInputContract& Shader::GetInputContract() const
  361. {
  362. return m_asset->GetInputContract(m_supervariantIndex);
  363. }
  364. const ShaderOutputContract& Shader::GetOutputContract() const
  365. {
  366. return m_asset->GetOutputContract(m_supervariantIndex);
  367. }
  368. const RHI::PipelineState* Shader::AcquirePipelineState(const RHI::PipelineStateDescriptor& descriptor) const
  369. {
  370. return m_pipelineStateCache->AcquirePipelineState(m_pipelineLibraryHandle, descriptor);
  371. }
  372. const RHI::Ptr<RHI::ShaderResourceGroupLayout>& Shader::FindShaderResourceGroupLayout(const Name& shaderResourceGroupName) const
  373. {
  374. return m_asset->FindShaderResourceGroupLayout(shaderResourceGroupName, m_supervariantIndex);
  375. }
  376. const RHI::Ptr<RHI::ShaderResourceGroupLayout>& Shader::FindShaderResourceGroupLayout(uint32_t bindingSlot) const
  377. {
  378. return m_asset->FindShaderResourceGroupLayout(bindingSlot, m_supervariantIndex);
  379. }
  380. const RHI::Ptr<RHI::ShaderResourceGroupLayout>& Shader::FindFallbackShaderResourceGroupLayout() const
  381. {
  382. return m_asset->FindFallbackShaderResourceGroupLayout(m_supervariantIndex);
  383. }
  384. AZStd::span<const RHI::Ptr<RHI::ShaderResourceGroupLayout>> Shader::GetShaderResourceGroupLayouts() const
  385. {
  386. return m_asset->GetShaderResourceGroupLayouts(m_supervariantIndex);
  387. }
  388. Data::Instance<ShaderResourceGroup> Shader::CreateDrawSrgForShaderVariant(const ShaderOptionGroup& shaderOptions, bool compileTheSrg)
  389. {
  390. RHI::Ptr<RHI::ShaderResourceGroupLayout> drawSrgLayout = m_asset->GetDrawSrgLayout(GetSupervariantIndex());
  391. Data::Instance<ShaderResourceGroup> drawSrg;
  392. if (drawSrgLayout)
  393. {
  394. drawSrg = RPI::ShaderResourceGroup::Create(m_asset, GetSupervariantIndex(), drawSrgLayout->GetName());
  395. if (drawSrgLayout->HasShaderVariantKeyFallbackEntry())
  396. {
  397. drawSrg->SetShaderVariantKeyFallbackValue(shaderOptions.GetShaderVariantKeyFallbackValue());
  398. }
  399. if (compileTheSrg)
  400. {
  401. drawSrg->Compile();
  402. }
  403. }
  404. return drawSrg;
  405. }
  406. Data::Instance<ShaderResourceGroup> Shader::CreateDefaultDrawSrg(bool compileTheSrg)
  407. {
  408. return CreateDrawSrgForShaderVariant(m_asset->GetDefaultShaderOptions(), compileTheSrg);
  409. }
  410. const Data::Asset<ShaderAsset>& Shader::GetAsset() const
  411. {
  412. return m_asset;
  413. }
  414. RHI::DrawListTag Shader::GetDrawListTag() const
  415. {
  416. return m_drawListTag;
  417. }
  418. } // namespace RPI
  419. } // namespace AZ