ShaderCollection.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AtomCore/std/containers/vector_set.h>
  9. #include <AzCore/Asset/AssetSerializer.h>
  10. #include <Atom/RPI.Reflect/Material/ShaderCollection.h>
  11. #include <Atom/RHI/RHISystemInterface.h>
  12. #include <Atom/RHI/DrawListTagRegistry.h>
  13. namespace AZ
  14. {
  15. namespace RPI
  16. {
  17. //! This allows ShaderCollection::Item to serialize only a ShaderVariantId rather than the ShaderOptionsGroup object,
  18. //! but still provide the corresponding ShaderOptionsGroup for use at runtime.
  19. //! RenderStates will be modified at runtime as well. It will be merged into the RenderStates stored in the corresponding ShaderVariant.
  20. class ShaderVariantReferenceSerializationEvents
  21. : public SerializeContext::IEventHandler
  22. {
  23. //! Called right before we start reading from the instance pointed by classPtr.
  24. void OnReadBegin(void* classPtr) override
  25. {
  26. ShaderCollection::Item* shaderVariantReference = reinterpret_cast<ShaderCollection::Item*>(classPtr);
  27. shaderVariantReference->m_shaderVariantId = shaderVariantReference->m_shaderOptionGroup.GetShaderVariantId();
  28. }
  29. //! Called right after we finish writing data to the instance pointed at by classPtr.
  30. void OnWriteEnd(void* classPtr) override
  31. {
  32. ShaderCollection::Item* shaderVariantReference = reinterpret_cast<ShaderCollection::Item*>(classPtr);
  33. if (shaderVariantReference->m_shaderAsset.IsReady())
  34. {
  35. shaderVariantReference->m_shaderOptionGroup = ShaderOptionGroup{
  36. shaderVariantReference->m_shaderAsset->GetShaderOptionGroupLayout(),
  37. shaderVariantReference->m_shaderVariantId
  38. };
  39. }
  40. else
  41. {
  42. // No worries, eventually the Material::Init will end up
  43. // calling InitializeShaderOptionGroup() and @m_shaderOptionGroup
  44. // will get proper data.
  45. shaderVariantReference->m_shaderOptionGroup = {};
  46. shaderVariantReference->m_shaderAsset.QueueLoad(); // Not necessary to call QueueLoad, but doesn't hurt either.
  47. }
  48. }
  49. };
  50. void ShaderCollection::Reflect(AZ::ReflectContext* context)
  51. {
  52. ShaderCollection::Item::Reflect(context);
  53. NameReflectionMapForIndex::Reflect(context);
  54. if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  55. {
  56. serializeContext->Class<ShaderCollection>()
  57. ->Version(5)
  58. ->Field("ShaderItems", &ShaderCollection::m_shaderItems)
  59. ->Field("ShaderTagIndexMap", &ShaderCollection::m_shaderTagIndexMap)
  60. ;
  61. }
  62. }
  63. void ShaderCollection::Item::Reflect(AZ::ReflectContext* context)
  64. {
  65. if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  66. {
  67. serializeContext->Enum<ShaderCollection::Item::DrawItemType>()
  68. ->Value("Raster", ShaderCollection::Item::DrawItemType::Raster)
  69. ->Value("Dispatch", ShaderCollection::Item::DrawItemType::Dispatch)
  70. ->Value("Deferred", ShaderCollection::Item::DrawItemType::Deferred)
  71. ->Value("RayTracing", ShaderCollection::Item::DrawItemType::RayTracing)
  72. ->Value("Custom", ShaderCollection::Item::DrawItemType::Custom);
  73. serializeContext->Class<ShaderCollection::Item>()
  74. ->Version(7)
  75. ->EventHandler<ShaderVariantReferenceSerializationEvents>()
  76. ->Field("ShaderAsset", &ShaderCollection::Item::m_shaderAsset)
  77. ->Field("ShaderVariantId", &ShaderCollection::Item::m_shaderVariantId)
  78. ->Field("Enabled", &ShaderCollection::Item::m_enabled)
  79. ->Field("OwnedShaderOptionIndices", &ShaderCollection::Item::m_ownedShaderOptionIndices)
  80. ->Field("ShaderTag", &ShaderCollection::Item::m_shaderTag)
  81. ->Field("DrawItemType", &ShaderCollection::Item::m_drawItemType);
  82. }
  83. if (BehaviorContext* behaviorContext = azrtti_cast<BehaviorContext*>(context))
  84. {
  85. behaviorContext->Class<Item>("ShaderCollectionItem")
  86. ->Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Automation)
  87. ->Attribute(AZ::Script::Attributes::Category, "Shader")
  88. ->Attribute(AZ::Script::Attributes::Module, "shader")
  89. ->Method("GetShaderAsset", &Item::GetShaderAsset)
  90. ->Method("GetShaderAssetId", &Item::GetShaderAssetId)
  91. ->Method("GetShaderVariantId", &Item::GetShaderVariantId)
  92. ->Method("GetShaderOptionGroup", &Item::GetShaderOptionGroup)
  93. ->Method("GetDrawItemType", &Item::GetDrawItemType)
  94. ->Method("MaterialOwnsShaderOption", static_cast<bool (Item::*)(const Name&) const>(&Item::MaterialOwnsShaderOption));
  95. }
  96. }
  97. size_t ShaderCollection::size() const
  98. {
  99. return m_shaderItems.size();
  100. }
  101. ShaderCollection::iterator ShaderCollection::begin()
  102. {
  103. return m_shaderItems.begin();
  104. }
  105. ShaderCollection::const_iterator ShaderCollection::begin() const
  106. {
  107. return m_shaderItems.begin();
  108. }
  109. ShaderCollection::iterator ShaderCollection::end()
  110. {
  111. return m_shaderItems.end();
  112. }
  113. ShaderCollection::const_iterator ShaderCollection::end() const
  114. {
  115. return m_shaderItems.end();
  116. }
  117. ShaderCollection::Item::Item()
  118. : m_renderStatesOverlay(RHI::GetInvalidRenderStates())
  119. {
  120. }
  121. ShaderCollection::Item& ShaderCollection::operator[](size_t i)
  122. {
  123. return m_shaderItems[i];
  124. }
  125. const ShaderCollection::Item& ShaderCollection::operator[](size_t i) const
  126. {
  127. return m_shaderItems[i];
  128. }
  129. ShaderCollection::Item& ShaderCollection::operator[](const AZ::Name& shaderTag)
  130. {
  131. return m_shaderItems[m_shaderTagIndexMap.Find(shaderTag).GetIndex()];
  132. }
  133. const ShaderCollection::Item& ShaderCollection::operator[](const AZ::Name& shaderTag) const
  134. {
  135. return m_shaderItems[m_shaderTagIndexMap.Find(shaderTag).GetIndex()];
  136. }
  137. bool ShaderCollection::HasShaderTag(const AZ::Name& shaderTag) const
  138. {
  139. return (m_shaderTagIndexMap.Find(shaderTag).IsValid());
  140. }
  141. void ShaderCollection::TryReplaceShaderAsset(const Data::Asset<ShaderAsset>& newShaderAsset)
  142. {
  143. for (auto& shaderItem : m_shaderItems)
  144. {
  145. shaderItem.TryReplaceShaderAsset(newShaderAsset);
  146. }
  147. }
  148. bool ShaderCollection::InitializeShaderOptionGroups()
  149. {
  150. for (auto& shaderItem : m_shaderItems)
  151. {
  152. if (!shaderItem.InitializeShaderOptionGroup())
  153. {
  154. return false;
  155. }
  156. }
  157. return true;
  158. }
  159. ShaderCollection::Item::Item(
  160. const Data::Asset<ShaderAsset>& shaderAsset, const AZ::Name& shaderTag, DrawItemType drawItemType, ShaderVariantId variantId)
  161. : m_renderStatesOverlay(RHI::GetInvalidRenderStates())
  162. , m_shaderAsset(shaderAsset)
  163. , m_shaderVariantId(variantId)
  164. , m_shaderTag(shaderTag)
  165. , m_drawItemType(drawItemType)
  166. , m_shaderOptionGroup(shaderAsset->GetShaderOptionGroupLayout(), variantId)
  167. {
  168. }
  169. ShaderCollection::Item::Item(
  170. Data::Asset<ShaderAsset>&& shaderAsset, const AZ::Name& shaderTag, DrawItemType drawItemType, ShaderVariantId variantId)
  171. : m_renderStatesOverlay(RHI::GetInvalidRenderStates())
  172. , m_shaderAsset(AZStd::move(shaderAsset))
  173. , m_shaderVariantId(variantId)
  174. , m_shaderTag(shaderTag)
  175. , m_drawItemType(drawItemType)
  176. , m_shaderOptionGroup(shaderAsset->GetShaderOptionGroupLayout(), variantId)
  177. {
  178. }
  179. const Data::Asset<ShaderAsset>& ShaderCollection::Item::GetShaderAsset() const
  180. {
  181. return m_shaderAsset;
  182. }
  183. const ShaderVariantId& ShaderCollection::Item::GetShaderVariantId() const
  184. {
  185. return m_shaderOptionGroup.GetShaderVariantId();
  186. }
  187. const ShaderOptionGroup* ShaderCollection::Item::GetShaderOptions() const
  188. {
  189. return &m_shaderOptionGroup;
  190. }
  191. ShaderOptionGroup* ShaderCollection::Item::GetShaderOptions()
  192. {
  193. return &m_shaderOptionGroup;
  194. }
  195. bool ShaderCollection::Item::MaterialOwnsShaderOption(const AZ::Name& shaderOptionName) const
  196. {
  197. return m_ownedShaderOptionIndices.contains(m_shaderOptionGroup.FindShaderOptionIndex(shaderOptionName));
  198. }
  199. bool ShaderCollection::Item::MaterialOwnsShaderOption(ShaderOptionIndex shaderOptionIndex) const
  200. {
  201. return m_ownedShaderOptionIndices.contains(shaderOptionIndex);
  202. }
  203. const RHI::RenderStates* ShaderCollection::Item::GetRenderStatesOverlay() const
  204. {
  205. return &m_renderStatesOverlay;
  206. }
  207. RHI::RenderStates* ShaderCollection::Item::GetRenderStatesOverlay()
  208. {
  209. return &m_renderStatesOverlay;
  210. }
  211. RHI::DrawListTag ShaderCollection::Item::GetDrawListTagOverride() const
  212. {
  213. return m_drawListTagOverride;
  214. }
  215. void ShaderCollection::Item::SetDrawListTagOverride(RHI::DrawListTag drawList)
  216. {
  217. m_drawListTagOverride = drawList;
  218. }
  219. void ShaderCollection::Item::SetDrawListTagOverride(const AZ::Name& drawListName)
  220. {
  221. if (drawListName.IsEmpty())
  222. {
  223. m_drawListTagOverride.Reset();
  224. return;
  225. }
  226. RHI::DrawListTagRegistry* drawListTagRegistry = RHI::RHISystemInterface::Get()->GetDrawListTagRegistry();
  227. // Note: we should use FindTag instead of AcquireTag to avoid occupy DrawListTag entries.
  228. RHI::DrawListTag newTag = drawListTagRegistry->FindTag(drawListName);
  229. if (newTag.IsNull())
  230. {
  231. AZ_Error("ShaderCollection", false, "Failed to set draw list with name: %s.", drawListName.GetCStr());
  232. return;
  233. }
  234. m_drawListTagOverride = newTag;
  235. }
  236. void ShaderCollection::Item::SetEnabled(bool enabled)
  237. {
  238. m_enabled = enabled;
  239. }
  240. bool ShaderCollection::Item::IsEnabled() const
  241. {
  242. return m_enabled;
  243. }
  244. const AZ::Name& ShaderCollection::Item::GetShaderTag() const
  245. {
  246. return m_shaderTag;
  247. }
  248. const Data::AssetId& ShaderCollection::Item::GetShaderAssetId() const
  249. {
  250. return m_shaderAsset->GetId();
  251. }
  252. const AZ::RPI::ShaderOptionGroup& ShaderCollection::Item::GetShaderOptionGroup() const
  253. {
  254. return m_shaderOptionGroup;
  255. }
  256. bool ShaderCollection::Item::InitializeShaderOptionGroup()
  257. {
  258. if (!m_shaderAsset.IsReady())
  259. {
  260. return false;
  261. }
  262. m_shaderOptionGroup = ShaderOptionGroup{
  263. m_shaderAsset->GetShaderOptionGroupLayout(),
  264. m_shaderVariantId };
  265. return true;
  266. }
  267. void ShaderCollection::Item::TryReplaceShaderAsset(const Data::Asset<ShaderAsset>& newShaderAsset)
  268. {
  269. if (newShaderAsset.GetId() != m_shaderAsset.GetId())
  270. {
  271. return;
  272. }
  273. m_shaderAsset = newShaderAsset;
  274. [[maybe_unused]] bool success = InitializeShaderOptionGroup();
  275. AZ_Assert(success, "Failed to InitializeShaderOptionGroup using shaderAsset with uuid=%s and hint=%s"
  276. , newShaderAsset.GetId().ToFixedString().c_str(), newShaderAsset.GetHint().c_str());
  277. }
  278. } // namespace RPI
  279. } // namespace AZ