3
0

ShaderResourceGroup.cpp 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RHI.Reflect/ShaderDataMappings.h>
  9. #include <Atom/RPI.Public/Shader/ShaderResourceGroup.h>
  10. #include <AtomCore/Instance/InstanceDatabase.h>
  11. namespace AZ
  12. {
  13. namespace RPI
  14. {
  // Category name string for this class (presumably used as the window/category of trace output — TODO confirm).
  15. const char* ShaderResourceGroup::s_traceCategoryName = "ShaderResourceGroup";
  // Default-constructed image instance; GetImage() returns a reference to it when the lookup fails.
  16. const Data::Instance<Image> ShaderResourceGroup::s_nullImage;
  // Default-constructed buffer instance; GetBuffer() returns a reference to it when the lookup fails.
  17. const Data::Instance<Buffer> ShaderResourceGroup::s_nullBuffer;
  18. Data::InstanceId ShaderResourceGroup::MakeSrgPoolInstanceId(
  19. const Data::Asset<ShaderAsset>& shaderAsset, const SupervariantIndex& supervariantIndex, const AZ::Name& srgName)
  20. {
  21. AZ_Assert(!srgName.IsEmpty(), "Invalid ShaderResourceGroup name");
  22. // Let's find the srg layout with the given name, because it contains the azsl file path of origin
  23. // which is essential to uniquely identify an SRG and avoid redundant copies in memory.
  24. auto srgLayout = shaderAsset->FindShaderResourceGroupLayout(srgName, supervariantIndex);
  25. AZ_Assert(
  26. srgLayout != nullptr,
  27. "Failed to find SRG with name %s, using supervariantIndex %u from shaderAsset %s",
  28. srgName.GetCStr(),
  29. supervariantIndex.GetIndex(),
  30. shaderAsset.GetHint().c_str());
  31. // Create the InstanceId by combining data from the SRG name and layout. This value does not need to be unique between
  32. // asset IDs because the data can be shared as long as the names and layouts match.
  33. const AZ::Uuid instanceUuid = AZ::Uuid::CreateName(srgLayout->GetUniqueId()) + AZ::Uuid::CreateName(srgName.GetStringView());
  34. // Create a union to split the 64 bit layout hash into 32 bit unsigned integers for use as instance ID sub IDs
  35. union {
  36. AZ::HashValue64 hash64;
  37. struct
  38. {
  39. uint32_t x;
  40. uint32_t y;
  41. };
  42. } hashUnion;
  43. hashUnion.hash64 = srgLayout->GetHash();
  44. // Use the supervariantIndex and layout hash as the subIds for the InstanceId
  45. return Data::InstanceId::CreateUuid(instanceUuid, { supervariantIndex.GetIndex(), hashUnion.x, hashUnion.y });
  46. }
  47. Data::Instance<ShaderResourceGroup> ShaderResourceGroup::Create(
  48. const Data::Asset<ShaderAsset>& shaderAsset, const AZ::Name& srgName)
  49. {
  50. // retrieve the supervariantIndex by searching for the default supervariant name, this will
  51. // allow the shader asset to properly handle the RPI::ShaderSystem supervariant
  52. SupervariantIndex supervariantIndex = shaderAsset->GetSupervariantIndex(AZ::Name(""));
  53. SrgInitParams initParams{ supervariantIndex, srgName };
  54. auto anyInitParams = AZStd::any(initParams);
  55. return Data::InstanceDatabase<ShaderResourceGroup>::Instance().Create(shaderAsset, &anyInitParams);
  56. }
  57. Data::Instance<ShaderResourceGroup> ShaderResourceGroup::Create(
  58. const Data::Asset<ShaderAsset>& shaderAsset, const SupervariantIndex& supervariantIndex, const AZ::Name& srgName)
  59. {
  60. SrgInitParams initParams{ supervariantIndex, srgName };
  61. auto anyInitParams = AZStd::any(initParams);
  62. return Data::InstanceDatabase<ShaderResourceGroup>::Instance().Create(shaderAsset, &anyInitParams);
  63. }
  64. Data::Instance<ShaderResourceGroup> ShaderResourceGroup::CreateInternal(ShaderAsset& shaderAsset, const AZStd::any* anySrgInitParams)
  65. {
  66. AZ_Assert(anySrgInitParams, "Invalid SrgInitParams");
  67. auto srgInitParams = AZStd::any_cast<SrgInitParams>(*anySrgInitParams);
  68. Data::Instance<ShaderResourceGroup> srg = aznew ShaderResourceGroup();
  69. const RHI::ResultCode resultCode = srg->Init(shaderAsset, srgInitParams.m_supervariantIndex, srgInitParams.m_srgName);
  70. if (resultCode != RHI::ResultCode::Success)
  71. {
  72. return nullptr;
  73. }
  74. return srg;
  75. }
  76. RHI::ResultCode ShaderResourceGroup::Init(ShaderAsset& shaderAsset, const SupervariantIndex& supervariantIndex, const AZ::Name& srgName)
  77. {
  78. const auto& lay = shaderAsset.FindShaderResourceGroupLayout(srgName, supervariantIndex);
  79. m_layout = lay.get();
  80. if (!m_layout)
  81. {
  82. AZ_Assert(false, "ShaderResourceGroup cannot be initialized due to invalid ShaderResourceGroupLayout");
  83. return RHI::ResultCode::Fail;
  84. }
  85. m_pool = ShaderResourceGroupPool::FindOrCreate(
  86. AZ::Data::Asset<ShaderAsset>(&shaderAsset, AZ::Data::AssetLoadBehavior::PreLoad), supervariantIndex, srgName);
  87. AZ_Assert(m_layout->GetHash() == m_pool->GetRHIPool()->GetLayout()->GetHash(), "This can happen if two shaders are including the same partial srg from different .azsl shader files and adding more custom entries to the srg. Recommendation is to just make a bigger SRG that can be shared between the two shaders.");
  88. if (!m_pool)
  89. {
  90. return RHI::ResultCode::Fail;
  91. }
  92. m_shaderResourceGroup = m_pool->CreateRHIShaderResourceGroup();
  93. if (!m_shaderResourceGroup)
  94. {
  95. return RHI::ResultCode::Fail;
  96. }
  97. m_shaderResourceGroup->SetName(m_pool->GetRHIPool()->GetName());
  98. m_data = RHI::ShaderResourceGroupData(RHI::MultiDevice::AllDevices, m_layout);
  99. m_asset = { &shaderAsset, AZ::Data::AssetLoadBehavior::PreLoad };
  100. // The RPI groups match the same dimensions as the RHI group.
  101. m_imageGroup.resize(m_layout->GetGroupSizeForImages());
  102. m_bufferGroup.resize(m_layout->GetGroupSizeForBuffers());
  103. m_isInitialized = true;
  104. return RHI::ResultCode::Success;
  105. }
  106. void ShaderResourceGroup::Compile()
  107. {
  108. m_shaderResourceGroup->Compile(m_data);
  109. //Mask is passed to RHI in the Compile call so we can reset it here
  110. m_data.ResetUpdateMask();
  111. }
  112. bool ShaderResourceGroup::IsQueuedForCompile() const
  113. {
  114. return m_shaderResourceGroup->IsQueuedForCompile();
  115. }
  116. RHI::ShaderInputBufferIndex ShaderResourceGroup::FindShaderInputBufferIndex(const Name& name) const
  117. {
  118. return m_layout->FindShaderInputBufferIndex(name);
  119. }
  120. RHI::ShaderInputImageIndex ShaderResourceGroup::FindShaderInputImageIndex(const Name& name) const
  121. {
  122. return m_layout->FindShaderInputImageIndex(name);
  123. }
  124. RHI::ShaderInputSamplerIndex ShaderResourceGroup::FindShaderInputSamplerIndex(const Name& name) const
  125. {
  126. return m_layout->FindShaderInputSamplerIndex(name);
  127. }
  128. RHI::ShaderInputConstantIndex ShaderResourceGroup::FindShaderInputConstantIndex(const Name& name) const
  129. {
  130. return m_layout->FindShaderInputConstantIndex(name);
  131. }
  132. RHI::ShaderInputBufferUnboundedArrayIndex ShaderResourceGroup::FindShaderInputBufferUnboundedArrayIndex(const Name& name) const
  133. {
  134. return m_layout->FindShaderInputBufferUnboundedArrayIndex(name);
  135. }
  136. RHI::ShaderInputImageUnboundedArrayIndex ShaderResourceGroup::FindShaderInputImageUnboundedArrayIndex(const Name& name) const
  137. {
  138. return m_layout->FindShaderInputImageUnboundedArrayIndex(name);
  139. }
  140. const RHI::ShaderResourceGroupLayout* ShaderResourceGroup::GetLayout() const
  141. {
  142. return m_layout;
  143. }
  144. RHI::ShaderResourceGroup* ShaderResourceGroup::GetRHIShaderResourceGroup()
  145. {
  146. return m_shaderResourceGroup.get();
  147. }
  148. bool ShaderResourceGroup::SetShaderVariantKeyFallbackValue(const ShaderVariantKey& shaderKey)
  149. {
  150. uint32_t keySize = GetLayout()->GetShaderVariantKeyFallbackSize();
  151. if (keySize == 0)
  152. {
  153. return false;
  154. }
  155. auto shaderFallbackIndex = GetLayout()->GetShaderVariantKeyFallbackConstantIndex();
  156. if (!shaderFallbackIndex.IsValid())
  157. {
  158. return false;
  159. }
  160. return SetConstantRaw(shaderFallbackIndex, shaderKey.data(), 0, AZStd::min(keySize, (uint32_t) ShaderVariantKeyBitCount) / 8);
  161. }
  162. bool ShaderResourceGroup::HasShaderVariantKeyFallbackEntry() const
  163. {
  164. return GetLayout()->HasShaderVariantKeyFallbackEntry();
  165. }
  166. bool ShaderResourceGroup::SetImage(RHI::ShaderInputNameIndex& inputIndex, const Data::Instance<Image>& image, uint32_t arrayIndex)
  167. {
  168. if (inputIndex.ValidateOrFindImageIndex(GetLayout()))
  169. {
  170. return SetImage(inputIndex.GetImageIndex(), image, arrayIndex);
  171. }
  172. return false;
  173. }
  174. bool ShaderResourceGroup::SetImage(RHI::ShaderInputImageIndex inputIndex, const Data::Instance<Image>& image, uint32_t arrayIndex)
  175. {
  176. const RHI::ImageView* imageView = image ? image->GetImageView() : nullptr;
  177. if (m_data.SetImageView(inputIndex, imageView, arrayIndex))
  178. {
  179. const RHI::Interval interval = m_layout->GetGroupInterval(inputIndex);
  180. // Track the RPI image entry at the same slot.
  181. m_imageGroup[interval.m_min + arrayIndex] = image;
  182. return true;
  183. }
  184. return false;
  185. }
  186. bool ShaderResourceGroup::SetImageArray(RHI::ShaderInputNameIndex& inputIndex, AZStd::span<const Data::Instance<Image>> images, uint32_t arrayIndex)
  187. {
  188. if (inputIndex.ValidateOrFindImageIndex(GetLayout()))
  189. {
  190. return SetImageArray(inputIndex.GetImageIndex(), images, arrayIndex);
  191. }
  192. return false;
  193. }
  194. bool ShaderResourceGroup::SetImageArray(RHI::ShaderInputImageIndex inputIndex, AZStd::span<const Data::Instance<Image>> images, uint32_t arrayIndex)
  195. {
  196. if (GetLayout()->ValidateAccess(inputIndex, arrayIndex + static_cast<uint32_t>(images.size()) - 1))
  197. {
  198. bool isValidAll = true;
  199. for (size_t i = 0; i < images.size(); ++i)
  200. {
  201. isValidAll &= SetImage(inputIndex, images[i], static_cast<uint32_t>(arrayIndex + i));
  202. }
  203. return isValidAll;
  204. }
  205. return false;
  206. }
  207. const Data::Instance<Image>& ShaderResourceGroup::GetImage(RHI::ShaderInputNameIndex& inputIndex, uint32_t arrayIndex) const
  208. {
  209. if (inputIndex.ValidateOrFindImageIndex(GetLayout()))
  210. {
  211. return GetImage(inputIndex.GetImageIndex(), arrayIndex);
  212. }
  213. return s_nullImage;
  214. }
  215. const Data::Instance<Image>& ShaderResourceGroup::GetImage(RHI::ShaderInputImageIndex inputIndex, uint32_t arrayIndex) const
  216. {
  217. if (m_layout->ValidateAccess(inputIndex, arrayIndex))
  218. {
  219. const RHI::Interval interval = m_layout->GetGroupInterval(inputIndex);
  220. return m_imageGroup[interval.m_min + arrayIndex];
  221. }
  222. return s_nullImage;
  223. }
  224. AZStd::span<const Data::Instance<Image>> ShaderResourceGroup::GetImageArray(RHI::ShaderInputNameIndex& inputIndex) const
  225. {
  226. if (inputIndex.ValidateOrFindImageIndex(GetLayout()))
  227. {
  228. return GetImageArray(inputIndex.GetImageIndex());
  229. }
  230. return {};
  231. }
  232. AZStd::span<const Data::Instance<Image>> ShaderResourceGroup::GetImageArray(RHI::ShaderInputImageIndex inputIndex) const
  233. {
  234. if (m_layout->ValidateAccess(inputIndex, 0))
  235. {
  236. const RHI::Interval interval = m_layout->GetGroupInterval(inputIndex);
  237. return AZStd::span<const Data::Instance<Image>>(&m_imageGroup[interval.m_min], interval.m_max - interval.m_min);
  238. }
  239. return {};
  240. }
  241. bool ShaderResourceGroup::SetImageView(RHI::ShaderInputNameIndex& inputIndex, const RHI::ImageView* imageView, uint32_t arrayIndex)
  242. {
  243. if (inputIndex.ValidateOrFindImageIndex(GetLayout()))
  244. {
  245. return SetImageView(inputIndex.GetImageIndex(), imageView, arrayIndex);
  246. }
  247. return false;
  248. }
  249. bool ShaderResourceGroup::SetImageView(RHI::ShaderInputImageIndex inputIndex, const RHI::ImageView* imageView, uint32_t arrayIndex)
  250. {
  251. if (m_data.SetImageView(inputIndex, imageView, arrayIndex))
  252. {
  253. const RHI::Interval interval = m_layout->GetGroupInterval(inputIndex);
  254. // Reset the RPI image entry, since an RHI version now takes precedence.
  255. m_imageGroup[interval.m_min + arrayIndex] = nullptr;
  256. return true;
  257. }
  258. return false;
  259. }
  260. bool ShaderResourceGroup::SetImageViewArray(RHI::ShaderInputNameIndex& inputIndex, AZStd::span<const RHI::ImageView* const> imageViews, uint32_t arrayIndex)
  261. {
  262. if (inputIndex.ValidateOrFindImageIndex(GetLayout()))
  263. {
  264. return SetImageViewArray(inputIndex.GetImageIndex(), imageViews, arrayIndex);
  265. }
  266. return false;
  267. }
  268. bool ShaderResourceGroup::SetImageViewArray(RHI::ShaderInputImageIndex inputIndex, AZStd::span<const RHI::ImageView * const> imageViews, uint32_t arrayIndex)
  269. {
  270. if (GetLayout()->ValidateAccess(inputIndex, arrayIndex + static_cast<uint32_t>(imageViews.size()) - 1))
  271. {
  272. bool isValidAll = true;
  273. for (size_t i = 0; i < imageViews.size(); ++i)
  274. {
  275. isValidAll &= SetImageView(inputIndex, imageViews[i], static_cast<uint32_t>(arrayIndex + i));
  276. }
  277. return isValidAll;
  278. }
  279. return false;
  280. }
  281. bool ShaderResourceGroup::SetImageViewUnboundedArray(RHI::ShaderInputImageUnboundedArrayIndex inputIndex, AZStd::span<const RHI::ImageView* const> imageViews)
  282. {
  283. return m_data.SetImageViewUnboundedArray(inputIndex, imageViews);
  284. }
  285. bool ShaderResourceGroup::SetBufferView(RHI::ShaderInputNameIndex& inputIndex, const RHI::BufferView *bufferView, uint32_t arrayIndex)
  286. {
  287. if (inputIndex.ValidateOrFindBufferIndex(GetLayout()))
  288. {
  289. return SetBufferView(inputIndex.GetBufferIndex(), bufferView, arrayIndex);
  290. }
  291. return false;
  292. }
  293. bool ShaderResourceGroup::SetBufferView(RHI::ShaderInputBufferIndex inputIndex, const RHI::BufferView *bufferView, uint32_t arrayIndex)
  294. {
  295. if (m_data.SetBufferView(inputIndex, bufferView, arrayIndex))
  296. {
  297. const RHI::Interval interval = m_layout->GetGroupInterval(inputIndex);
  298. // Reset the RPI image entry, since an RHI version now takes precedence.
  299. m_bufferGroup[interval.m_min + arrayIndex] = nullptr;
  300. return true;
  301. }
  302. return false;
  303. }
  304. bool ShaderResourceGroup::SetBufferViewArray(RHI::ShaderInputNameIndex& inputIndex, AZStd::span<const RHI::BufferView* const> bufferViews, uint32_t arrayIndex)
  305. {
  306. if (inputIndex.ValidateOrFindBufferIndex(GetLayout()))
  307. {
  308. return SetBufferViewArray(inputIndex.GetBufferIndex(), bufferViews, arrayIndex);
  309. }
  310. return false;
  311. }
  312. bool ShaderResourceGroup::SetBufferViewArray(RHI::ShaderInputBufferIndex inputIndex, AZStd::span<const RHI::BufferView * const> bufferViews, uint32_t arrayIndex)
  313. {
  314. if (GetLayout()->ValidateAccess(inputIndex, arrayIndex + static_cast<uint32_t>(bufferViews.size()) - 1))
  315. {
  316. bool isValidAll = true;
  317. for (size_t i = 0; i < bufferViews.size(); ++i)
  318. {
  319. isValidAll &= SetBufferView(inputIndex, bufferViews[i], static_cast<uint32_t>(arrayIndex + i));
  320. }
  321. return isValidAll;
  322. }
  323. return false;
  324. }
  325. bool ShaderResourceGroup::SetBufferViewUnboundedArray(RHI::ShaderInputBufferUnboundedArrayIndex inputIndex, AZStd::span<const RHI::BufferView * const> bufferViews)
  326. {
  327. return m_data.SetBufferViewUnboundedArray(inputIndex, bufferViews);
  328. }
  329. bool ShaderResourceGroup::SetSampler(RHI::ShaderInputNameIndex& inputIndex, const RHI::SamplerState& sampler, uint32_t arrayIndex)
  330. {
  331. if (inputIndex.ValidateOrFindSamplerIndex(GetLayout()))
  332. {
  333. return SetSampler(inputIndex.GetSamplerIndex(), sampler, arrayIndex);
  334. }
  335. return false;
  336. }
  337. bool ShaderResourceGroup::SetSampler(RHI::ShaderInputSamplerIndex inputIndex, const RHI::SamplerState& sampler, uint32_t arrayIndex)
  338. {
  339. return m_data.SetSampler(inputIndex, sampler, arrayIndex);
  340. }
  341. bool ShaderResourceGroup::SetSamplerArray(RHI::ShaderInputNameIndex& inputIndex, AZStd::span<const RHI::SamplerState> samplers, uint32_t arrayIndex)
  342. {
  343. if (inputIndex.ValidateOrFindSamplerIndex(GetLayout()))
  344. {
  345. return SetSamplerArray(inputIndex.GetSamplerIndex(), samplers, arrayIndex);
  346. }
  347. return false;
  348. }
  349. bool ShaderResourceGroup::SetSamplerArray(RHI::ShaderInputSamplerIndex inputIndex, AZStd::span<const RHI::SamplerState> samplers, uint32_t arrayIndex)
  350. {
  351. return m_data.SetSamplerArray(inputIndex, samplers, arrayIndex);
  352. }
  353. bool ShaderResourceGroup::SetConstantRaw(RHI::ShaderInputNameIndex& inputIndex, const void* bytes, uint32_t byteCount)
  354. {
  355. if (inputIndex.ValidateOrFindConstantIndex(GetLayout()))
  356. {
  357. return SetConstantRaw(inputIndex.GetConstantIndex(), bytes, byteCount);
  358. }
  359. return false;
  360. }
  361. bool ShaderResourceGroup::SetConstantRaw(RHI::ShaderInputConstantIndex inputIndex, const void* bytes, uint32_t byteCount)
  362. {
  363. return m_data.SetConstantRaw(inputIndex, bytes, byteCount);
  364. }
  365. bool ShaderResourceGroup::SetConstantRaw(RHI::ShaderInputNameIndex& inputIndex, const void* bytes, uint32_t byteOffset, uint32_t byteCount)
  366. {
  367. if (inputIndex.ValidateOrFindConstantIndex(GetLayout()))
  368. {
  369. return SetConstantRaw(inputIndex.GetConstantIndex(), bytes, byteOffset, byteCount);
  370. }
  371. return false;
  372. }
  373. bool ShaderResourceGroup::SetConstantRaw(RHI::ShaderInputConstantIndex inputIndex, const void* bytes, uint32_t byteOffset, uint32_t byteCount)
  374. {
  375. return m_data.SetConstantRaw(inputIndex, bytes, byteOffset, byteCount);
  376. }
  377. bool ShaderResourceGroup::ApplyDataMappings(const RHI::ShaderDataMappings& mappings)
  378. {
  379. bool success = true;
  380. success = success && ApplyDataMappingArray(mappings.m_colorMappings);
  381. success = success && ApplyDataMappingArray(mappings.m_uintMappings);
  382. success = success && ApplyDataMappingArray(mappings.m_floatMappings);
  383. success = success && ApplyDataMappingArray(mappings.m_float2Mappings);
  384. success = success && ApplyDataMappingArray(mappings.m_float3Mappings);
  385. success = success && ApplyDataMappingArray(mappings.m_float4Mappings);
  386. success = success && ApplyDataMappingArray(mappings.m_matrix3x3Mappings);
  387. success = success && ApplyDataMappingArray(mappings.m_matrix4x4Mappings);
  388. return success;
  389. }
  390. const RHI::ConstPtr<RHI::ImageView>& ShaderResourceGroup::GetImageView(
  391. RHI::ShaderInputNameIndex& inputIndex, uint32_t arrayIndex) const
  392. {
  393. inputIndex.ValidateOrFindImageIndex(GetLayout());
  394. return GetImageView(inputIndex.GetImageIndex(), arrayIndex);
  395. }
  396. const RHI::ConstPtr<RHI::ImageView>& ShaderResourceGroup::GetImageView(RHI::ShaderInputImageIndex inputIndex, uint32_t arrayIndex) const
  397. {
  398. return m_data.GetImageView(inputIndex, arrayIndex);
  399. }
  400. AZStd::span<const RHI::ConstPtr<RHI::ImageView>> ShaderResourceGroup::GetImageViewArray(
  401. RHI::ShaderInputNameIndex& inputIndex) const
  402. {
  403. inputIndex.ValidateOrFindImageIndex(GetLayout());
  404. return GetImageViewArray(inputIndex.GetImageIndex());
  405. }
  406. AZStd::span<const RHI::ConstPtr<RHI::ImageView>> ShaderResourceGroup::GetImageViewArray(RHI::ShaderInputImageIndex inputIndex) const
  407. {
  408. return m_data.GetImageViewArray(inputIndex);
  409. }
  410. const RHI::ConstPtr<RHI::BufferView>& ShaderResourceGroup::GetBufferView(
  411. RHI::ShaderInputNameIndex& inputIndex, uint32_t arrayIndex) const
  412. {
  413. inputIndex.ValidateOrFindBufferIndex(GetLayout());
  414. return GetBufferView(inputIndex.GetBufferIndex(), arrayIndex);
  415. }
  416. const RHI::ConstPtr<RHI::BufferView>& ShaderResourceGroup::GetBufferView(RHI::ShaderInputBufferIndex inputIndex, uint32_t arrayIndex) const
  417. {
  418. return m_data.GetBufferView(inputIndex, arrayIndex);
  419. }
  420. AZStd::span<const RHI::ConstPtr<RHI::BufferView>> ShaderResourceGroup::GetBufferViewArray(
  421. RHI::ShaderInputNameIndex& inputIndex) const
  422. {
  423. inputIndex.ValidateOrFindBufferIndex(GetLayout());
  424. return GetBufferViewArray(inputIndex.GetBufferIndex());
  425. }
  426. AZStd::span<const RHI::ConstPtr<RHI::BufferView>> ShaderResourceGroup::GetBufferViewArray(RHI::ShaderInputBufferIndex inputIndex) const
  427. {
  428. return m_data.GetBufferViewArray(inputIndex);
  429. }
  430. bool ShaderResourceGroup::SetBuffer(RHI::ShaderInputNameIndex& inputIndex, const Data::Instance<Buffer>& buffer, uint32_t arrayIndex)
  431. {
  432. if (inputIndex.ValidateOrFindBufferIndex(GetLayout()))
  433. {
  434. return SetBuffer(inputIndex.GetBufferIndex(), buffer, arrayIndex);
  435. }
  436. return false;
  437. }
  438. bool ShaderResourceGroup::SetBuffer(RHI::ShaderInputBufferIndex inputIndex, const Data::Instance<Buffer>& buffer, uint32_t arrayIndex)
  439. {
  440. const auto bufferView =
  441. buffer ? buffer->GetBufferView() : nullptr;
  442. if (m_data.SetBufferView(inputIndex, bufferView, arrayIndex))
  443. {
  444. const RHI::Interval interval = m_layout->GetGroupInterval(inputIndex);
  445. // Track the RPI buffer entry at the same slot.
  446. m_bufferGroup[interval.m_min + arrayIndex] = buffer;
  447. return true;
  448. }
  449. return false;
  450. }
  451. bool ShaderResourceGroup::SetBufferArray(RHI::ShaderInputNameIndex& inputIndex, AZStd::span<const Data::Instance<Buffer>> buffers, uint32_t arrayIndex)
  452. {
  453. if (inputIndex.ValidateOrFindBufferIndex(GetLayout()))
  454. {
  455. return SetBufferArray(inputIndex.GetBufferIndex(), buffers, arrayIndex);
  456. }
  457. return false;
  458. }
  459. bool ShaderResourceGroup::SetBufferArray(RHI::ShaderInputBufferIndex inputIndex, AZStd::span<const Data::Instance<Buffer>> buffers, uint32_t arrayIndex)
  460. {
  461. if (GetLayout()->ValidateAccess(inputIndex, arrayIndex + static_cast<uint32_t>(buffers.size()) - 1))
  462. {
  463. bool isValidAll = true;
  464. for (size_t i = 0; i < buffers.size(); ++i)
  465. {
  466. isValidAll &= SetBuffer(inputIndex, buffers[i], static_cast<uint32_t>(arrayIndex + i));
  467. }
  468. return isValidAll;
  469. }
  470. return false;
  471. }
  472. const Data::Instance<Buffer>& ShaderResourceGroup::GetBuffer(RHI::ShaderInputNameIndex& inputIndex, uint32_t arrayIndex) const
  473. {
  474. if (inputIndex.ValidateOrFindBufferIndex(GetLayout()))
  475. {
  476. return GetBuffer(inputIndex.GetBufferIndex(), arrayIndex);
  477. }
  478. return s_nullBuffer;
  479. }
  480. const Data::Instance<Buffer>& ShaderResourceGroup::GetBuffer(RHI::ShaderInputBufferIndex inputIndex, uint32_t arrayIndex) const
  481. {
  482. if (m_layout->ValidateAccess(inputIndex, arrayIndex))
  483. {
  484. const RHI::Interval interval = m_layout->GetGroupInterval(inputIndex);
  485. return m_bufferGroup[interval.m_min + arrayIndex];
  486. }
  487. return s_nullBuffer;
  488. }
  489. AZStd::span<const Data::Instance<Buffer>> ShaderResourceGroup::GetBufferArray(RHI::ShaderInputNameIndex& inputIndex) const
  490. {
  491. if (inputIndex.ValidateOrFindBufferIndex(GetLayout()))
  492. {
  493. return GetBufferArray(inputIndex.GetBufferIndex());
  494. }
  495. return {};
  496. }
  497. AZStd::span<const Data::Instance<Buffer>> ShaderResourceGroup::GetBufferArray(RHI::ShaderInputBufferIndex inputIndex) const
  498. {
  499. if (m_layout->ValidateAccess(inputIndex, 0))
  500. {
  501. const RHI::Interval interval = m_layout->GetGroupInterval(inputIndex);
  502. return AZStd::span<const Data::Instance<Buffer>>(&m_bufferGroup[interval.m_min], interval.m_max - interval.m_min);
  503. }
  504. return {};
  505. }
  506. void ShaderResourceGroup::ResetViews()
  507. {
  508. m_data.ResetViews();
  509. }
  510. const RHI::SamplerState& ShaderResourceGroup::GetSampler(RHI::ShaderInputNameIndex& inputIndex, uint32_t arrayIndex) const
  511. {
  512. inputIndex.ValidateOrFindSamplerIndex(GetLayout());
  513. return GetSampler(inputIndex.GetSamplerIndex(), arrayIndex);
  514. }
  515. const RHI::SamplerState& ShaderResourceGroup::GetSampler(RHI::ShaderInputSamplerIndex inputIndex, uint32_t arrayIndex) const
  516. {
  517. return m_data.GetSampler(inputIndex, arrayIndex);
  518. }
  519. AZStd::span<const RHI::SamplerState> ShaderResourceGroup::GetSamplerArray(RHI::ShaderInputNameIndex& inputIndex) const
  520. {
  521. inputIndex.ValidateOrFindSamplerIndex(GetLayout());
  522. return GetSamplerArray(inputIndex.GetSamplerIndex());
  523. }
  524. AZStd::span<const RHI::SamplerState> ShaderResourceGroup::GetSamplerArray(RHI::ShaderInputSamplerIndex inputIndex) const
  525. {
  526. return m_data.GetSamplerArray(inputIndex);
  527. }
  528. AZStd::span<const uint8_t> ShaderResourceGroup::GetConstantRaw(RHI::ShaderInputNameIndex& inputIndex) const
  529. {
  530. inputIndex.ValidateOrFindConstantIndex(GetLayout());
  531. return GetConstantRaw(inputIndex.GetConstantIndex());
  532. }
  533. AZStd::span<const uint8_t> ShaderResourceGroup::GetConstantRaw(RHI::ShaderInputConstantIndex inputIndex) const
  534. {
  535. return m_data.GetConstantRaw(inputIndex);
  536. }
  537. bool ShaderResourceGroup::CopyShaderResourceGroupData(const ShaderResourceGroup& other)
  538. {
  539. bool isFullCopy = true;
  540. // Copy Buffer Shader Inputs
  541. for (const RHI::ShaderInputBufferDescriptor& desc : m_layout->GetShaderInputListForBuffers())
  542. {
  543. RHI::ShaderInputBufferIndex otherIndex = other.m_layout->FindShaderInputBufferIndex(desc.m_name);
  544. if (!otherIndex.IsValid())
  545. {
  546. isFullCopy = false;
  547. continue;
  548. }
  549. [[maybe_unused]] const RHI::ShaderInputBufferDescriptor& otherDesc = other.m_layout->GetShaderInput(otherIndex);
  550. AZ_Error(
  551. "ShaderResourceGroup",
  552. desc.m_access == otherDesc.m_access && desc.m_count == otherDesc.m_count &&
  553. desc.m_strideSize == otherDesc.m_strideSize && desc.m_type == otherDesc.m_type,
  554. "ShaderInputBuffer %s does not match when copying shader resource group data",
  555. desc.m_name.GetCStr());
  556. RHI::ShaderInputBufferIndex index = m_layout->FindShaderInputBufferIndex(desc.m_name);
  557. auto bufferViewArray = other.GetBufferViewArray(otherIndex);
  558. auto bufferArray = other.GetBufferArray(otherIndex);
  559. AZ_Assert(bufferViewArray.size() == bufferArray.size(), "Different size between buffers and buffer views");
  560. for (uint32_t i = 0; i < bufferViewArray.size(); ++i)
  561. {
  562. if (bufferArray[i])
  563. {
  564. SetBuffer(index, bufferArray[i], i);
  565. }
  566. else
  567. {
  568. SetBufferView(index, bufferViewArray[i].get(), i);
  569. }
  570. }
  571. }
  572. // Copy Image Shader Inputs
  573. for (const RHI::ShaderInputImageDescriptor& desc : m_layout->GetShaderInputListForImages())
  574. {
  575. RHI::ShaderInputImageIndex otherIndex = other.m_layout->FindShaderInputImageIndex(desc.m_name);
  576. if (!otherIndex.IsValid())
  577. {
  578. isFullCopy = false;
  579. continue;
  580. }
  581. [[maybe_unused]] const RHI::ShaderInputImageDescriptor& otherDesc = other.m_layout->GetShaderInput(otherIndex);
  582. AZ_Error(
  583. "ShaderResourceGroup",
  584. desc.m_access == otherDesc.m_access && desc.m_count == otherDesc.m_count &&
  585. desc.m_type == otherDesc.m_type,
  586. "ShaderInputImage %s does not match when copying shader resource group data",
  587. desc.m_name.GetCStr());
  588. RHI::ShaderInputImageIndex index = m_layout->FindShaderInputImageIndex(desc.m_name);
  589. auto imageViewArray = other.GetImageViewArray(otherIndex);
  590. auto imageArray = other.GetImageArray(otherIndex);
  591. AZ_Assert(imageViewArray.size() == imageArray.size(), "Different size between image and image views");
  592. for (uint32_t i = 0; i < imageViewArray.size(); ++i)
  593. {
  594. if (imageArray[i])
  595. {
  596. SetImage(index, imageArray[i].get(), i);
  597. }
  598. else
  599. {
  600. SetImageView(index, imageViewArray[i].get(), i);
  601. }
  602. }
  603. }
  604. // Copy Sample Shader Inputs
  605. for (const RHI::ShaderInputSamplerDescriptor& desc : m_layout->GetShaderInputListForSamplers())
  606. {
  607. RHI::ShaderInputSamplerIndex otherIndex = other.m_layout->FindShaderInputSamplerIndex(desc.m_name);
  608. if (!otherIndex.IsValid())
  609. {
  610. isFullCopy = false;
  611. continue;
  612. }
  613. [[maybe_unused]] const RHI::ShaderInputSamplerDescriptor& otherDesc = other.m_layout->GetShaderInput(otherIndex);
  614. AZ_Error(
  615. "ShaderResourceGroup",
  616. desc.m_count == otherDesc.m_count,
  617. "ShaderInputSampler %s does not match when copying shader resource group data",
  618. desc.m_name.GetCStr());
  619. AZStd::span<const RHI::SamplerState> samplerViewArray = other.m_data.GetSamplerArray(otherIndex);
  620. SetSamplerArray(m_layout->FindShaderInputSamplerIndex(desc.m_name), samplerViewArray);
  621. }
  622. // Copy Constants Shader Inputs
  623. for (const RHI::ShaderInputConstantDescriptor& desc : m_layout->GetShaderInputListForConstants())
  624. {
  625. RHI::ShaderInputConstantIndex otherIndex = other.m_layout->FindShaderInputConstantIndex(desc.m_name);
  626. if (!otherIndex.IsValid())
  627. {
  628. isFullCopy = false;
  629. continue;
  630. }
  631. [[maybe_unused]] const RHI::ShaderInputConstantDescriptor& otherDesc = other.m_layout->GetShaderInput(otherIndex);
  632. AZ_Error(
  633. "ShaderResourceGroup",
  634. desc.m_constantByteCount == otherDesc.m_constantByteCount,
  635. "ShaderInputConstant %s does not match when copying shader resource group data",
  636. desc.m_name.GetCStr());
  637. AZStd::span<const uint8_t> constantRaw = other.m_data.GetConstantRaw(otherIndex);
  638. SetConstantRaw(m_layout->FindShaderInputConstantIndex(desc.m_name), constantRaw.data(), static_cast<uint32_t>(constantRaw.size()));
  639. }
  640. // Copy Unbound Buffer Array Inputs
  641. AZStd::vector<const RHI::BufferView*> bufferViewPtrArray;
  642. for (const RHI::ShaderInputBufferUnboundedArrayDescriptor& desc : m_layout->GetShaderInputListForBufferUnboundedArrays())
  643. {
  644. RHI::ShaderInputBufferUnboundedArrayIndex otherIndex =
  645. other.m_layout->FindShaderInputBufferUnboundedArrayIndex(desc.m_name);
  646. if (!otherIndex.IsValid())
  647. {
  648. isFullCopy = false;
  649. continue;
  650. }
  651. [[maybe_unused]] const RHI::ShaderInputBufferUnboundedArrayDescriptor& otherDesc =
  652. other.m_layout->GetShaderInput(otherIndex);
  653. AZ_Error(
  654. "ShaderResourceGroup",
  655. desc.m_type == otherDesc.m_type && desc.m_access == otherDesc.m_access,
  656. "ShaderInputBufferUnboundedArray %s does not match when copying shader resource group data",
  657. desc.m_name.GetCStr());
  658. AZStd::span<const RHI::ConstPtr<RHI::BufferView>> bufferViewArray = other.m_data.GetBufferViewUnboundedArray(otherIndex);
  659. bufferViewPtrArray.resize(bufferViewArray.size());
  660. AZStd::transform(
  661. bufferViewArray.begin(),
  662. bufferViewArray.end(),
  663. bufferViewPtrArray.begin(),
  664. [](auto& item)
  665. {
  666. return item.get();
  667. });
  668. SetBufferViewUnboundedArray(m_layout->FindShaderInputBufferUnboundedArrayIndex(desc.m_name), bufferViewPtrArray);
  669. }
  670. // Copy Unbound Image Array Inputs
  671. AZStd::vector<const RHI::ImageView*> imageViewPtrArray;
  672. for (const RHI::ShaderInputImageUnboundedArrayDescriptor& desc : m_layout->GetShaderInputListForImageUnboundedArrays())
  673. {
  674. RHI::ShaderInputImageUnboundedArrayIndex otherIndex = other.m_layout->FindShaderInputImageUnboundedArrayIndex(desc.m_name);
  675. if (!otherIndex.IsValid())
  676. {
  677. isFullCopy = false;
  678. continue;
  679. }
  680. [[maybe_unused]] const RHI::ShaderInputImageUnboundedArrayDescriptor& otherDesc =
  681. other.m_layout->GetShaderInput(otherIndex);
  682. AZ_Error(
  683. "ShaderResourceGroup",
  684. desc.m_type == otherDesc.m_type && desc.m_access == otherDesc.m_access,
  685. "ShaderInputImageUnboundedArray %s does not match when copying shader resource group data",
  686. desc.m_name.GetCStr());
  687. AZStd::span<const RHI::ConstPtr<RHI::ImageView>> imageViewArray = other.m_data.GetImageViewUnboundedArray(otherIndex);
  688. imageViewPtrArray.resize(imageViewArray.size());
  689. AZStd::transform(
  690. imageViewArray.begin(),
  691. imageViewArray.end(),
  692. imageViewPtrArray.begin(),
  693. [](auto& item)
  694. {
  695. return item.get();
  696. });
  697. SetImageViewUnboundedArray(m_layout->FindShaderInputImageUnboundedArrayIndex(desc.m_name), imageViewPtrArray);
  698. }
  699. // Copy Bindless Inputs
  700. AZStd::vector<bool> isReadOnlyBuffer;
  701. AZStd::vector<bool> isReadOnlyImage;
  702. for (const auto& entry : other.m_data.GetBindlessResourceViews())
  703. {
  704. RHI::ShaderInputBufferIndex otherIndirectResourceBufferIndex = entry.first.first;
  705. const RHI::ConstPtr<RHI::BufferView>& indirectResourceBuffer = other.GetBufferView(otherIndirectResourceBufferIndex);
  706. RHI::ShaderInputBufferDescriptor otherIndirectResourceBufferDesc =
  707. other.m_layout->GetShaderInput(otherIndirectResourceBufferIndex);
  708. RHI::ShaderInputBufferIndex indirectResourceBufferIndex =
  709. FindShaderInputBufferIndex(otherIndirectResourceBufferDesc.m_name);
  710. if (!indirectResourceBufferIndex.IsValid())
  711. {
  712. continue;
  713. }
  714. size_t maxSize = entry.second.m_bindlessResources.size();
  715. bufferViewPtrArray.clear();
  716. bufferViewPtrArray.reserve(maxSize);
  717. imageViewPtrArray.clear();
  718. imageViewPtrArray.reserve(maxSize);
  719. isReadOnlyBuffer.clear();
  720. isReadOnlyBuffer.reserve(maxSize);
  721. isReadOnlyImage.clear();
  722. isReadOnlyImage.reserve(maxSize);
  723. for (ConstPtr<RHI::ResourceView> resourceView : entry.second.m_bindlessResources)
  724. {
  725. RHI::BindlessResourceType type = entry.second.m_bindlessResourceType;
  726. switch (type)
  727. {
  728. case RHI::BindlessResourceType::m_ByteAddressBuffer:
  729. case RHI::BindlessResourceType::m_RWByteAddressBuffer:
  730. bufferViewPtrArray.push_back(static_cast<const RHI::BufferView*>(resourceView.get()));
  731. isReadOnlyBuffer.push_back(type != RHI::BindlessResourceType::m_RWByteAddressBuffer);
  732. break;
  733. case RHI::BindlessResourceType::m_Texture2D:
  734. case RHI::BindlessResourceType::m_RWTexture2D:
  735. case RHI::BindlessResourceType::m_TextureCube:
  736. imageViewPtrArray.push_back(static_cast<const RHI::ImageView*>(resourceView.get()));
  737. isReadOnlyBuffer.push_back(type != RHI::BindlessResourceType::m_RWTexture2D);
  738. break;
  739. default:
  740. AZ_Assert(false, "Invalid RHI::BindlessResourceType %d", type);
  741. continue;
  742. }
  743. }
  744. if (!bufferViewPtrArray.empty())
  745. {
  746. SetBindlessViews(
  747. indirectResourceBufferIndex,
  748. indirectResourceBuffer.get(),
  749. bufferViewPtrArray,
  750. nullptr,
  751. isReadOnlyBuffer,
  752. entry.first.second);
  753. }
  754. if (!imageViewPtrArray.empty())
  755. {
  756. SetBindlessViews(
  757. indirectResourceBufferIndex,
  758. indirectResourceBuffer.get(),
  759. imageViewPtrArray,
  760. nullptr,
  761. isReadOnlyImage,
  762. entry.first.second);
  763. }
  764. }
  765. return isFullCopy;
  766. }
  767. void ShaderResourceGroup::SetBindlessViews(
  768. RHI::ShaderInputBufferIndex indirectResourceBufferIndex,
  769. const RHI::BufferView* indirectResourceBuffer,
  770. AZStd::span<const RHI::ImageView* const> imageViews,
  771. uint32_t* outIndices,
  772. AZStd::span<bool> isViewReadOnly,
  773. uint32_t arrayIndex)
  774. {
  775. m_data.SetBindlessViews(
  776. indirectResourceBufferIndex, indirectResourceBuffer, imageViews, outIndices, isViewReadOnly, arrayIndex);
  777. }
  778. void ShaderResourceGroup::SetBindlessViews(
  779. RHI::ShaderInputBufferIndex indirectResourceBufferIndex,
  780. const RHI::BufferView* indirectResourceBuffer,
  781. AZStd::span<const RHI::BufferView* const> bufferViews,
  782. uint32_t* outIndices,
  783. AZStd::span<bool> isViewReadOnly,
  784. uint32_t arrayIndex)
  785. {
  786. m_data.SetBindlessViews(indirectResourceBufferIndex, indirectResourceBuffer,
  787. bufferViews, outIndices, isViewReadOnly, arrayIndex);
  788. }
  789. } // namespace RPI
  790. } // namespace AZ