ShaderVariantTreeAssetCreator.cpp 13 KB


  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RPI.Edit/Shader/ShaderVariantTreeAssetCreator.h>
  9. #include <Atom/RPI.Reflect/Shader/ShaderOptionGroup.h>
  10. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  11. namespace AZ
  12. {
  13. namespace RPI
  14. {
  // Largest number of values a single shader option may have and still be accepted by BuildTree(), which rejects
  // larger options. It bounds the per-node child arrays (GetValuesCount() + 1 entries each).
  // The specific limit is arbitrary and still to be reviewed.
  static constexpr uint32_t MaxShaderVariantValues{ 1000 };
  17. AZ::Outcome<void, AZStd::string> ShaderVariantTreeAssetCreator::ValidateStableIdsAreUnique(const AZStd::vector<ShaderVariantListSourceData::VariantInfo>& shaderVariantList)
  18. {
  19. AZStd::unordered_map<ShaderVariantStableId, uint32_t> stableIdToIndexMap;
  20. stableIdToIndexMap.reserve(shaderVariantList.size());
  21. uint32_t sourceVariantIndex = 0;
  22. for (const ShaderVariantListSourceData::VariantInfo& variantInfo : shaderVariantList)
  23. {
  24. const ShaderVariantStableId variantInfoStableId{variantInfo.m_stableId};
  25. if (variantInfoStableId.IsNull() || variantInfoStableId == RootShaderVariantStableId)
  26. {
  27. return AZ::Failure(AZStd::string::format("The variant at index=[%u] has StableId=[%u], which is forbidden.", sourceVariantIndex, variantInfoStableId.GetIndex()));
  28. }
  29. if (stableIdToIndexMap.find(variantInfoStableId) != stableIdToIndexMap.end())
  30. {
  31. const uint32_t existingVariantIndex = stableIdToIndexMap.at(variantInfoStableId);
  32. return AZ::Failure(AZStd::string::format("The variant at index=[%u] is trying to use StableId=[%u] which is already taken by variant at index=[%u]"
  33. , sourceVariantIndex, variantInfoStableId.GetIndex(), existingVariantIndex));
  34. }
  35. stableIdToIndexMap.emplace(variantInfoStableId, sourceVariantIndex);
  36. sourceVariantIndex++;
  37. }
  38. return AZ::Success();
  39. }
  40. void ShaderVariantTreeAssetCreator::Begin(const AZ::Data::AssetId& assetId)
  41. {
  42. BeginCommon(assetId);
  43. }
  44. void ShaderVariantTreeAssetCreator::SetShaderOptionGroupLayout(const RPI::ShaderOptionGroupLayout& shaderOptionGroupLayout)
  45. {
  46. if (ValidateIsReady())
  47. {
  48. m_shaderOptionGroupLayout = &shaderOptionGroupLayout;
  49. }
  50. }
  51. void ShaderVariantTreeAssetCreator::SetVariantInfos(const AZStd::vector<ShaderVariantListSourceData::VariantInfo>& variantInfos)
  52. {
  53. if (ValidateIsReady())
  54. {
  55. // Add +1 space for the root variant.
  56. m_variantInfos.reserve(variantInfos.size() + 1);
  57. // When building the tree it'll be important that the first variant in the list
  58. // is the root variant.
  59. m_variantInfos.push_back(ShaderVariantListSourceData::VariantInfo());
  60. for (const auto& variantInfo : variantInfos)
  61. {
  62. m_variantInfos.push_back(variantInfo);
  63. }
  64. }
  65. }
  66. //! Finalizes and assigns ownership of the asset to result, if successful.
  67. //! Otherwise false is returned and result is left untouched.
  68. bool ShaderVariantTreeAssetCreator::End(Data::Asset<ShaderVariantTreeAsset>& result)
  69. {
  70. if (!ValidateIsReady())
  71. {
  72. return false;
  73. }
  74. if (!m_shaderOptionGroupLayout)
  75. {
  76. ReportError("No ShaderOptionGroupLayout has been set. Failed to finalize the ShaderVariantTreeAsset.");
  77. return false;
  78. }
  79. if (m_variantInfos.size() == 0)
  80. {
  81. ReportError("The list of source variants is not valid. Failed to finalize the ShaderVariantTreeAsset.");
  82. return false;
  83. }
  84. if (!EndInternal(result))
  85. {
  86. return false;
  87. }
  88. if (!m_asset->FinalizeAfterLoad())
  89. {
  90. ReportError("Failed to finalize the ShaderVariantTreeAsset.");
  91. return false;
  92. }
  93. m_asset->SetReady();
  94. return EndCommon(result);
  95. }
  96. bool ShaderVariantTreeAssetCreator::EndInternal([[maybe_unused]] Data::Asset<ShaderVariantTreeAsset>& result)
  97. {
  98. // Temporary structure used for sorting and caching intermediate results
  99. struct OptionCache
  100. {
  101. AZ::Name m_optionName;
  102. AZ::Name m_valueName;
  103. RPI::ShaderOptionIndex m_optionIndex; // Cached m_optionName
  104. RPI::ShaderOptionValue m_value; // Cached m_valueName
  105. };
  106. AZStd::vector<OptionCache> optionList;
  107. // We can not have more options than the number of options in the layout:
  108. optionList.reserve(m_shaderOptionGroupLayout->GetShaderOptionCount());
  109. //Build the list of ShaderVariantId.
  110. AZStd::vector<ShaderVariantIdWithStableId> shaderVariantIds;
  111. shaderVariantIds.reserve(m_variantInfos.size());
  112. for (const ShaderVariantListSourceData::VariantInfo& variantInfo : m_variantInfos)
  113. {
  114. // Variants have their own set of option values so we rebuild the list for each variant:
  115. optionList.clear();
  116. // This loop will validate and cache the indices for each option value:
  117. for (const auto& shaderOption : variantInfo.m_options)
  118. {
  119. Name optionName{ shaderOption.first };
  120. Name optionValue{ shaderOption.second };
  121. auto optionIndex = m_shaderOptionGroupLayout->FindShaderOptionIndex(optionName);
  122. if (optionIndex.IsNull())
  123. {
  124. ReportError("Invalid shader option: %s", optionName.GetCStr());
  125. continue;
  126. }
  127. auto option = m_shaderOptionGroupLayout->GetShaderOption(optionIndex);
  128. auto value = option.FindValue(optionValue);
  129. if (value.IsNull())
  130. {
  131. ReportError("Invalid value (%s) for shader option: %s", optionValue.GetCStr(), optionName.GetCStr());
  132. continue;
  133. }
  134. optionList.push_back(OptionCache{ optionName, optionValue, optionIndex, value });
  135. }
  136. // The user might supply the option values in any order. Sort them now:
  137. AZStd::sort(optionList.begin(), optionList.end()
  138. , [](const OptionCache& left, const OptionCache& right)
  139. {
  140. // m_optionIndex is the cached index in the m_options vector (stored in the ShaderOptionGroupLayout)
  141. // m_options has already been sorted so the index *is* the option priority:
  142. return left.m_optionIndex < right.m_optionIndex;
  143. }
  144. );
  145. RPI::ShaderOptionGroup optionGroup(m_shaderOptionGroupLayout);
  146. for (const auto& optionCache : optionList)
  147. {
  148. auto option = m_shaderOptionGroupLayout->GetShaderOption(optionCache.m_optionIndex);
  149. // Assign the option value specified in the variant:
  150. option.Set(optionGroup, optionCache.m_value);
  151. }
  152. shaderVariantIds.push_back({optionGroup.GetShaderVariantId(), ShaderVariantStableId{variantInfo.m_stableId}});
  153. }
  154. return BuildTree(shaderVariantIds);
  155. }
  156. bool ShaderVariantTreeAssetCreator::BuildTree(const AZStd::vector<ShaderVariantIdWithStableId>& shaderVariantIdsWithStableId)
  157. {
  158. //! Helper struct to build a dynamically allocated tree. The tree is then serialized into an accelerated search structure
  159. struct TreeNode
  160. {
  161. ShaderVariantStableId m_variantStableId;
  162. AZStd::vector<AZStd::shared_ptr<TreeNode>> m_children;
  163. TreeNode()
  164. : m_variantStableId(ShaderVariantStableId{ ShaderVariantTreeAsset::UnspecifiedIndex })
  165. {
  166. }
  167. TreeNode(const ShaderVariantStableId& variantStableId)
  168. : m_variantStableId(variantStableId)
  169. {
  170. }
  171. //! Bakes a node into the variant search tree.
  172. //! position The position in the flat vector array of the tree where the node needs to be baked
  173. //! nextFree The position in the flat vector array of the tree where the next free nodes can start from
  174. //! node The node to bake.
  175. //! tree The tree to bake into. If null, the node is not baked and the number of nodes is returned.
  176. static uint32_t BuildNode(uint32_t position, uint32_t nextFree, TreeNode* node, ShaderVariantTreeAsset* tree = nullptr)
  177. {
  178. AZ_Assert(position < nextFree, "Invalid position for the current node");
  179. const uint32_t offsetToChildren = node->m_children.empty() ? 0 : nextFree - position;
  180. uint32_t childIndex = nextFree;
  181. nextFree += aznumeric_cast<uint32_t>(node->m_children.size());
  182. for (const auto& child : node->m_children)
  183. {
  184. if (child)
  185. {
  186. nextFree = BuildNode(childIndex, nextFree, child.get(), tree);
  187. }
  188. childIndex++;
  189. }
  190. if (tree)
  191. {
  192. (*tree).SetNode(position, ShaderVariantTreeNode{ node->m_variantStableId, offsetToChildren });
  193. node->m_children.clear();
  194. }
  195. return nextFree;
  196. }
  197. };
  198. const auto& options = m_shaderOptionGroupLayout->GetShaderOptions();
  199. // The first variant is always the root.
  200. auto treeRoot = AZStd::make_unique<TreeNode>();
  201. treeRoot->m_variantStableId = ShaderAsset::RootShaderVariantStableId;
  202. // We start from the next variant after the root.
  203. for (uint32_t variantIndex = 1u; variantIndex < shaderVariantIdsWithStableId.size(); variantIndex++)
  204. {
  205. const ShaderVariantIdWithStableId shaderVariantIdWithStableId = shaderVariantIdsWithStableId[variantIndex];
  206. auto optionValues = ShaderVariantTreeAsset::ConvertToValueChain(m_shaderOptionGroupLayout, shaderVariantIdWithStableId.m_shaderVariantId);
  207. auto treeNode = treeRoot.get();
  208. for (uint32_t optionIndex = 0; optionIndex < optionValues.size(); optionIndex++)
  209. {
  210. const uint32_t optionValue = optionValues[optionIndex];
  211. const ShaderOptionDescriptor& option = options[optionIndex];
  212. // Validation for unsupported features of the variant tree:
  213. // - Large range of integers
  214. // - Enums with gaps in their values
  215. if (option.GetValuesCount() > MaxShaderVariantValues)
  216. {
  217. ReportError("Large integer ranges are not supported.");
  218. continue;
  219. }
  220. if (option.GetMaxValue().GetIndex() - option.GetMinValue().GetIndex() + 1 != option.GetValuesCount())
  221. {
  222. ReportError("Enums with gaps are not supported.");
  223. continue;
  224. }
  225. // The first time we add all the children.
  226. if (treeNode->m_children.empty())
  227. {
  228. treeNode->m_children.resize(option.GetValuesCount() + 1, nullptr);
  229. }
  230. // If the child node at the correct index is still missing, create it.
  231. if (treeNode->m_children[optionValue + 1] == nullptr)
  232. {
  233. // The variant index of a non-leaf node is invalid.
  234. treeNode->m_children[optionValue + 1] = AZStd::make_shared<TreeNode>();
  235. }
  236. // Visit the next node.
  237. treeNode = treeNode->m_children[optionValue + 1].get();
  238. }
  239. // Set the variant index for the current node.
  240. treeNode->m_variantStableId = ShaderVariantStableId{ shaderVariantIdWithStableId.m_stableId };
  241. }
  242. // Calculate the total size of the tree, and construct it.
  243. const uint32_t treeSize = TreeNode::BuildNode(0, 1, treeRoot.get(), nullptr);
  244. m_asset->m_nodes =
  245. AZStd::vector<ShaderVariantTreeNode>(treeSize, ShaderVariantTreeNode());
  246. TreeNode::BuildNode(0, 1, treeRoot.get(), m_asset.Get());
  247. return true;
  248. }
  249. } // namespace RPI
  250. } // namespace AZ