MultiDeviceDrawPacketTests.cpp 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
#include "RHITestFixture.h"
#include <Atom/RHI.Reflect/RenderAttachmentLayout.h>
#include <Atom/RHI/DrawListContext.h>
#include <Atom/RHI/DrawListTagRegistry.h>
#include <Atom/RHI/DrawPacket.h>
#include <Atom/RHI/DrawPacketBuilder.h>
#include <Atom/RHI/PipelineState.h>
#include <Atom/RHI/ShaderResourceGroupPool.h>
#include <AzCore/Math/Random.h>
#include <AzCore/std/sort.h>
#include <Tests/Device.h>
#include <Tests/Factory.h>

#include <cstring>
  20. namespace UnitTest
  21. {
  22. using namespace AZ;
    //? TODO: May revert back to normal deviceCount and Mask
    // Device count and device mask shared by the tests below. They are currently pinned to a
    // single device: the count is 1 and the mask is RHI::MultiDevice::DefaultDevice.
    static constexpr auto LocalDeviceCount{1};
    static constexpr auto LocalDeviceMask{RHI::MultiDevice::DefaultDevice};
  26. struct MultiDeviceDrawItemData
  27. {
  28. MultiDeviceDrawItemData(SimpleLcgRandom& random, const RHI::Buffer* bufferEmpty, const RHI::PipelineState* psoEmpty)
  29. {
  30. m_pipelineState = psoEmpty;
  31. // Fill with deterministic random data to compare against.
  32. for (auto& streamBufferView : m_streamBufferViews)
  33. {
  34. streamBufferView =
  35. RHI::StreamBufferView{ *bufferEmpty, random.GetRandom(), random.GetRandom(), random.GetRandom() };
  36. }
  37. m_tag = RHI::DrawListTag(random.GetRandom() % RHI::Limits::Pipeline::DrawListTagCountMax);
  38. m_stencilRef = static_cast<uint8_t>(random.GetRandom());
  39. m_sortKey = random.GetRandom();
  40. }
  41. AZStd::array<RHI::StreamBufferView, RHI::Limits::Pipeline::StreamCountMax> m_streamBufferViews;
  42. const RHI::PipelineState* m_pipelineState;
  43. RHI::DrawListTag m_tag;
  44. RHI::DrawItemSortKey m_sortKey;
  45. uint8_t m_stencilRef;
  46. };
  47. struct MultiDeviceDrawPacketData
  48. {
  49. static constexpr const size_t DrawItemCountMax = 8;
  50. MultiDeviceDrawPacketData(SimpleLcgRandom& random)
  51. {
  52. RHI::BufferPoolDescriptor bufferPoolDesc;
  53. m_bufferPool = aznew RHI::BufferPool;
  54. m_bufferEmpty = aznew RHI::Buffer;
  55. m_bufferPool->Init(LocalDeviceMask, bufferPoolDesc);
  56. RHI::BufferInitRequest request;
  57. request.m_buffer = m_bufferEmpty.get();
  58. request.m_descriptor = RHI::BufferDescriptor{};
  59. m_bufferPool->InitBuffer(request);
  60. m_psoEmpty = aznew RHI::PipelineState;
  61. m_psoEmpty->m_deviceMask = LocalDeviceMask;
  62. m_psoEmpty->IterateDevices(
  63. [this](int deviceIndex)
  64. {
  65. this->m_psoEmpty->m_deviceObjects[deviceIndex] = RHI::Factory::Get().CreatePipelineState();
  66. return true;
  67. });
  68. for (auto& srg : m_srgs)
  69. {
  70. srg = aznew RHI::ShaderResourceGroup;
  71. srg->m_deviceMask = LocalDeviceMask;
  72. srg->IterateDevices(
  73. [&srg](int deviceIndex)
  74. {
  75. srg->m_deviceObjects[deviceIndex] = RHI::Factory::Get().CreateShaderResourceGroup();
  76. return true;
  77. });
  78. }
  79. unsigned int* data = reinterpret_cast<unsigned int*>(m_rootConstants.data());
  80. for (uint32_t i = 0; i < m_rootConstants.size() / sizeof(unsigned int); ++i)
  81. {
  82. data[i] = random.GetRandom();
  83. }
  84. for (size_t i = 0; i < DrawItemCountMax; ++i)
  85. {
  86. m_drawItemDatas.emplace_back(random, m_bufferEmpty.get(), m_psoEmpty.get());
  87. }
  88. m_indexBufferView =
  89. RHI::IndexBufferView(*m_bufferEmpty, random.GetRandom(), random.GetRandom(), RHI::IndexFormat::Uint16);
  90. }
  91. const auto Build(RHI::DrawPacketBuilder& builder)
  92. {
  93. builder.Begin(nullptr);
  94. for (auto& srgPtr : m_srgs)
  95. {
  96. builder.AddShaderResourceGroup(srgPtr.get());
  97. }
  98. builder.SetRootConstants(m_rootConstants);
  99. builder.SetIndexBufferView(m_indexBufferView);
  100. RHI::DrawListMask drawListMask;
  101. for (size_t i = 0; i < DrawItemCountMax; ++i)
  102. {
  103. const MultiDeviceDrawItemData& drawItemData = m_drawItemDatas[i];
  104. drawListMask[drawItemData.m_tag.GetIndex()] = true;
  105. RHI::DrawPacketBuilder::DrawRequest drawRequest;
  106. drawRequest.m_listTag = drawItemData.m_tag;
  107. drawRequest.m_sortKey = drawItemData.m_sortKey;
  108. drawRequest.m_stencilRef = drawItemData.m_stencilRef;
  109. drawRequest.m_streamBufferViews = drawItemData.m_streamBufferViews;
  110. drawRequest.m_pipelineState = drawItemData.m_pipelineState;
  111. builder.AddDrawItem(drawRequest);
  112. }
  113. const auto drawPacket = builder.End();
  114. EXPECT_NE(drawPacket, nullptr);
  115. EXPECT_EQ(drawPacket->GetDrawListMask(), drawListMask);
  116. EXPECT_EQ(drawPacket->GetDrawItemCount(), m_drawItemDatas.size());
  117. return drawPacket;
  118. }
  119. RHI::Ptr<RHI::BufferPool> m_bufferPool;
  120. RHI::Ptr<RHI::Buffer> m_bufferEmpty;
  121. RHI::Ptr<RHI::PipelineState> m_psoEmpty;
  122. RHI::Ptr<RHI::ShaderResourceGroupPool> m_srgPool;
  123. AZStd::array<RHI::Ptr<RHI::ShaderResourceGroup>, RHI::Limits::Pipeline::ShaderResourceGroupCountMax> m_srgs;
  124. AZStd::array<uint8_t, sizeof(unsigned int) * 4> m_rootConstants;
  125. RHI::IndexBufferView m_indexBufferView;
  126. AZStd::vector<MultiDeviceDrawItemData> m_drawItemDatas;
  127. };
  128. class MultiDeviceDrawPacketTest : public MultiDeviceRHITestFixture
  129. {
  130. protected:
  131. static const uint32_t s_randomSeed = 1234;
  132. RHI::DrawListContext m_drawListContext;
  133. AZStd::unique_ptr<AZ::RHI::RHISystem> m_rhiSystem;
  134. AZStd::unique_ptr<Factory> m_factory;
  135. public:
        //! Test setup; only forwards to the base fixture.
        void SetUp() override
        {
            MultiDeviceRHITestFixture::SetUp();
        }
        //! Test teardown; only forwards to the base fixture.
        void TearDown() override
        {
            MultiDeviceRHITestFixture::TearDown();
        }
  144. void DrawPacketEmpty()
  145. {
  146. RHI::DrawPacketBuilder builder(LocalDeviceMask);
  147. builder.Begin(nullptr);
  148. const auto drawPacket = builder.End();
  149. EXPECT_EQ(drawPacket, nullptr);
  150. }
  151. void DrawPacketNullItem()
  152. {
  153. RHI::DeviceDrawPacketBuilder builder;
  154. builder.Begin(nullptr);
  155. RHI::DeviceDrawPacketBuilder::DeviceDrawRequest drawRequest;
  156. builder.AddDrawItem(drawRequest);
  157. const RHI::DeviceDrawPacket* drawPacket = builder.End();
  158. EXPECT_EQ(drawPacket, nullptr);
  159. }
  160. void DrawPacketBuild()
  161. {
  162. AZ::SimpleLcgRandom random(s_randomSeed);
  163. MultiDeviceDrawPacketData drawPacketData(random);
  164. RHI::DrawPacketBuilder builder(LocalDeviceMask);
  165. const auto drawPacket = drawPacketData.Build(builder);
  166. }
  167. void DrawPacketBuildClearBuildNull()
  168. {
  169. AZ::SimpleLcgRandom random(s_randomSeed);
  170. MultiDeviceDrawPacketData drawPacketData(random);
  171. RHI::DrawPacketBuilder builder(LocalDeviceMask);
  172. auto drawPacket = drawPacketData.Build(builder);
  173. // Try to build a 'null' packet. This should result in a null pointer.
  174. builder.Begin(nullptr);
  175. drawPacket = builder.End();
  176. EXPECT_EQ(drawPacket.get(), nullptr);
  177. }
  178. void DrawListContextFilter()
  179. {
  180. AZ::SimpleLcgRandom random(s_randomSeed);
  181. MultiDeviceDrawPacketData drawPacketData(random);
  182. RHI::DrawPacketBuilder builder(LocalDeviceMask);
  183. auto drawPacket = drawPacketData.Build(builder);
  184. RHI::DrawListContext drawListContext;
  185. drawListContext.Init(RHI::DrawListMask{}.set());
  186. drawListContext.AddDrawPacket(drawPacket.get());
  187. for (size_t i = 0; i < drawPacket->GetDrawItemCount(); ++i)
  188. {
  189. RHI::DrawListTag tag = drawPacket->GetDrawListTag(i);
  190. RHI::DrawListView drawList = drawListContext.GetList(tag);
  191. EXPECT_TRUE(drawList.empty());
  192. }
  193. drawListContext.FinalizeLists();
  194. RHI::DrawListsByTag listsByTag;
  195. for (size_t i = 0; i < drawPacket->GetDrawItemCount(); ++i)
  196. {
  197. RHI::DrawListTag tag = drawPacket->GetDrawListTag(i);
  198. listsByTag[tag.GetIndex()].push_back(drawPacket->GetDrawItemProperties(i));
  199. }
  200. size_t tagIndex = 0;
  201. for (auto& drawList : listsByTag)
  202. {
  203. SortDrawList(drawList, RHI::DrawListSortType::KeyThenDepth);
  204. RHI::DrawListTag tag(tagIndex);
  205. RHI::DrawListView drawListView = drawListContext.GetList(tag);
  206. EXPECT_EQ(drawListView.size(), drawList.size());
  207. for (size_t i = 0; i < drawList.size(); ++i)
  208. {
  209. EXPECT_EQ(drawList[i], drawListView[i]);
  210. }
  211. tagIndex++;
  212. }
  213. drawListContext.Shutdown();
  214. }
  215. void DrawListContextNullFilter()
  216. {
  217. AZ::SimpleLcgRandom random(s_randomSeed);
  218. MultiDeviceDrawPacketData drawPacketData(random);
  219. RHI::DrawPacketBuilder builder{RHI::MultiDevice::DefaultDevice};
  220. auto drawPacket = drawPacketData.Build(builder);
  221. RHI::DrawListContext drawListContext;
  222. drawListContext.Init(RHI::DrawListMask{}); // Mask set to not contain any draw lists.
  223. drawListContext.AddDrawPacket(drawPacket.get());
  224. drawListContext.FinalizeLists();
  225. for (size_t i = 0; i < drawPacket->GetDrawItemCount(); ++i)
  226. {
  227. RHI::DrawListTag tag = drawPacket->GetDrawListTag(i);
  228. RHI::DrawListView drawList = drawListContext.GetList(tag);
  229. EXPECT_TRUE(drawList.empty());
  230. }
  231. drawListContext.Shutdown();
  232. }
  233. void DrawPacketClone()
  234. {
  235. AZ::SimpleLcgRandom random(s_randomSeed);
  236. MultiDeviceDrawPacketData drawPacketData(random);
  237. RHI::DrawPacketBuilder builder(LocalDeviceMask);
  238. const auto drawPacket = drawPacketData.Build(builder);
  239. RHI::DrawPacketBuilder builder2(LocalDeviceMask);
  240. const auto drawPacketClone = builder2.Clone(drawPacket.get());
  241. for (auto deviceIndex{ 0 }; deviceIndex < LocalDeviceCount; ++deviceIndex)
  242. {
  243. auto deviceDrawPacket{ drawPacket->GetDeviceDrawPacket(deviceIndex) };
  244. auto deviceDrawPacketClone{ drawPacketClone->GetDeviceDrawPacket(deviceIndex) };
  245. EXPECT_EQ(deviceDrawPacket->m_drawItemCount, deviceDrawPacketClone->m_drawItemCount);
  246. EXPECT_EQ(deviceDrawPacket->m_streamBufferViewCount, deviceDrawPacketClone->m_streamBufferViewCount);
  247. EXPECT_EQ(deviceDrawPacket->m_shaderResourceGroupCount, deviceDrawPacketClone->m_shaderResourceGroupCount);
  248. EXPECT_EQ(deviceDrawPacket->m_uniqueShaderResourceGroupCount, deviceDrawPacketClone->m_uniqueShaderResourceGroupCount);
  249. EXPECT_EQ(deviceDrawPacket->m_rootConstantSize, deviceDrawPacketClone->m_rootConstantSize);
  250. EXPECT_EQ(deviceDrawPacket->m_scissorsCount, deviceDrawPacketClone->m_scissorsCount);
  251. EXPECT_EQ(deviceDrawPacket->m_viewportsCount, deviceDrawPacketClone->m_viewportsCount);
  252. }
  253. const uint8_t drawItemCount =
  254. static_cast<uint8_t>(AZStd::min<size_t>(drawPacket->GetDrawItemCount(), MultiDeviceDrawPacketData::DrawItemCountMax));
  255. for (uint8_t i = 0; i < drawItemCount; ++i)
  256. {
  257. EXPECT_EQ(drawPacket->GetDrawListTag(i), drawPacketClone->GetDrawListTag(i));
  258. EXPECT_EQ(drawPacket->GetDrawFilterMask(i), drawPacketClone->GetDrawFilterMask(i));
  259. const auto* drawItem = drawPacket->GetDrawItem(i);
  260. const RHI::DrawItem* drawItemClone = drawPacketClone->GetDrawItem(i);
  261. // Check the clone is an actual copy not an identical pointer.
  262. EXPECT_NE(drawItem, drawItemClone);
  263. for (auto deviceIndex{ 0 }; deviceIndex < LocalDeviceCount; ++deviceIndex)
  264. {
  265. auto deviceDrawPacket{ drawPacket->GetDeviceDrawPacket(deviceIndex) };
  266. auto deviceDrawPacketClone{ drawPacketClone->GetDeviceDrawPacket(deviceIndex) };
  267. auto& deviceDrawItem{ drawItem->GetDeviceDrawItem(deviceIndex) };
  268. auto& deviceDrawItemClone{ drawItemClone->GetDeviceDrawItem(deviceIndex) };
  269. EXPECT_EQ(deviceDrawItem.m_arguments.m_type, deviceDrawItemClone.m_arguments.m_type);
  270. EXPECT_EQ(deviceDrawItem.m_pipelineState->GetType(), deviceDrawItemClone.m_pipelineState->GetType());
  271. EXPECT_EQ(deviceDrawItem.m_stencilRef, deviceDrawItemClone.m_stencilRef);
  272. EXPECT_EQ(deviceDrawItem.m_streamBufferViewCount, deviceDrawItemClone.m_streamBufferViewCount);
  273. EXPECT_EQ(deviceDrawItem.m_shaderResourceGroupCount, deviceDrawItemClone.m_shaderResourceGroupCount);
  274. EXPECT_EQ(deviceDrawItem.m_rootConstantSize, deviceDrawItemClone.m_rootConstantSize);
  275. EXPECT_EQ(deviceDrawItem.m_scissorsCount, deviceDrawItemClone.m_scissorsCount);
  276. EXPECT_EQ(deviceDrawItem.m_viewportsCount, deviceDrawItemClone.m_viewportsCount);
  277. uint8_t streamBufferViewCount = deviceDrawItem.m_streamBufferViewCount;
  278. uint8_t shaderResourceGroupCount = deviceDrawItem.m_shaderResourceGroupCount;
  279. uint8_t rootConstantSize = deviceDrawItem.m_rootConstantSize;
  280. uint8_t scissorsCount = deviceDrawItem.m_scissorsCount;
  281. uint8_t viewportsCount = deviceDrawItem.m_viewportsCount;
  282. for (uint8_t j = 0; j < streamBufferViewCount; ++j)
  283. {
  284. const RHI::DeviceStreamBufferView* streamBufferView = deviceDrawPacket->m_streamBufferViews + j;
  285. const RHI::DeviceStreamBufferView* streamBufferViewClone = deviceDrawPacketClone->m_streamBufferViews + j;
  286. EXPECT_EQ(streamBufferView->GetByteCount(), streamBufferViewClone->GetByteCount());
  287. EXPECT_EQ(streamBufferView->GetByteOffset(), streamBufferViewClone->GetByteOffset());
  288. EXPECT_EQ(streamBufferView->GetByteStride(), streamBufferViewClone->GetByteStride());
  289. EXPECT_EQ(streamBufferView->GetHash(), streamBufferViewClone->GetHash());
  290. }
  291. for (uint8_t j = 0; j < shaderResourceGroupCount; ++j)
  292. {
  293. EXPECT_EQ(*(deviceDrawItem.m_shaderResourceGroups + j), *(deviceDrawItemClone.m_shaderResourceGroups + j));
  294. }
  295. for (uint8_t j = 0; j < rootConstantSize; ++j)
  296. {
  297. EXPECT_EQ(*(deviceDrawItem.m_rootConstants + j), *(deviceDrawItemClone.m_rootConstants + j));
  298. }
  299. for (uint8_t j = 0; j < scissorsCount; ++j)
  300. {
  301. EXPECT_EQ(deviceDrawItem.m_scissors + j, deviceDrawItemClone.m_scissors + j);
  302. }
  303. for (uint8_t j = 0; j < viewportsCount; ++j)
  304. {
  305. EXPECT_EQ(deviceDrawItem.m_viewports + j, deviceDrawItemClone.m_viewports + j);
  306. }
  307. }
  308. }
  309. for (auto deviceIndex{ 0 }; deviceIndex < LocalDeviceCount; ++deviceIndex)
  310. {
  311. auto deviceDrawPacket{ drawPacket->GetDeviceDrawPacket(deviceIndex) };
  312. auto deviceDrawPacketClone{ drawPacketClone->GetDeviceDrawPacket(deviceIndex) };
  313. uint8_t streamBufferViewCount = deviceDrawPacket->m_streamBufferViewCount;
  314. uint8_t shaderResourceGroupCount = deviceDrawPacket->m_shaderResourceGroupCount;
  315. uint8_t uniqueShaderResourceGroupCount = deviceDrawPacket->m_uniqueShaderResourceGroupCount;
  316. uint8_t rootConstantSize = deviceDrawPacket->m_rootConstantSize;
  317. uint8_t scissorsCount = deviceDrawPacket->m_scissorsCount;
  318. uint8_t viewportsCount = deviceDrawPacket->m_viewportsCount;
  319. for (uint8_t i = 0; i < streamBufferViewCount; ++i)
  320. {
  321. const RHI::DeviceStreamBufferView* streamBufferView = deviceDrawPacket->m_streamBufferViews + i;
  322. const RHI::DeviceStreamBufferView* streamBufferViewClone = deviceDrawPacketClone->m_streamBufferViews + i;
  323. EXPECT_EQ(streamBufferView->GetByteCount(), streamBufferViewClone->GetByteCount());
  324. EXPECT_EQ(streamBufferView->GetByteOffset(), streamBufferViewClone->GetByteOffset());
  325. EXPECT_EQ(streamBufferView->GetByteStride(), streamBufferViewClone->GetByteStride());
  326. EXPECT_EQ(streamBufferView->GetHash(), streamBufferViewClone->GetHash());
  327. }
  328. for (uint8_t i = 0; i < shaderResourceGroupCount; ++i)
  329. {
  330. EXPECT_EQ(*(deviceDrawPacket->m_shaderResourceGroups + i), *(deviceDrawPacketClone->m_shaderResourceGroups + i));
  331. }
  332. for (uint8_t i = 0; i < uniqueShaderResourceGroupCount; ++i)
  333. {
  334. EXPECT_EQ(
  335. *(deviceDrawPacket->m_uniqueShaderResourceGroups + i), *(deviceDrawPacketClone->m_uniqueShaderResourceGroups + i));
  336. }
  337. for (uint8_t i = 0; i < rootConstantSize; ++i)
  338. {
  339. EXPECT_EQ(*(deviceDrawPacket->m_rootConstants + i), *(deviceDrawPacketClone->m_rootConstants + i));
  340. }
  341. for (uint8_t i = 0; i < scissorsCount; ++i)
  342. {
  343. EXPECT_EQ(deviceDrawPacket->m_scissors + i, deviceDrawPacketClone->m_scissors + i);
  344. }
  345. for (uint8_t i = 0; i < viewportsCount; ++i)
  346. {
  347. EXPECT_EQ(deviceDrawPacket->m_viewports + i, deviceDrawPacketClone->m_viewports + i);
  348. }
  349. }
  350. }
  351. void TestSetInstanceCount()
  352. {
  353. AZ::SimpleLcgRandom random(s_randomSeed);
  354. MultiDeviceDrawPacketData drawPacketData(random);
  355. RHI::DrawPacketBuilder builder(LocalDeviceMask);
  356. const auto drawPacket = drawPacketData.Build(builder);
  357. RHI::DrawPacketBuilder builder2(LocalDeviceMask);
  358. auto drawPacketClone = builder2.Clone(drawPacket.get());
  359. const uint8_t drawItemCount =
  360. static_cast<uint8_t>(AZStd::min<size_t>(drawPacket->GetDrawItemCount(), MultiDeviceDrawPacketData::DrawItemCountMax));
  361. // Test default value
  362. for (uint8_t i = 0; i < drawItemCount; ++i)
  363. {
  364. for (auto deviceIndex{ 0 }; deviceIndex < LocalDeviceCount; ++deviceIndex)
  365. {
  366. const auto& drawItemClone = drawPacketClone->m_drawItems[i].GetDeviceDrawItem(deviceIndex);
  367. EXPECT_EQ(drawItemClone.m_arguments.m_type, RHI::DrawType::Indexed);
  368. EXPECT_EQ(drawItemClone.m_arguments.m_indexed.m_instanceCount, 1);
  369. }
  370. }
  371. drawPacketClone->SetInstanceCount(12);
  372. for (uint8_t i = 0; i < drawItemCount; ++i)
  373. {
  374. for (auto deviceIndex{ 0 }; deviceIndex < LocalDeviceCount; ++deviceIndex)
  375. {
  376. const auto& drawItemClone = drawPacketClone->m_drawItems[i].GetDeviceDrawItem(deviceIndex);
  377. EXPECT_EQ(drawItemClone.m_arguments.m_indexed.m_instanceCount, 12);
  378. // Check that the original draw packet is not affected
  379. const auto& drawItem = drawPacket->m_drawItems[i].GetDeviceDrawItem(deviceIndex);
  380. EXPECT_EQ(drawItem.m_arguments.m_indexed.m_instanceCount, 1);
  381. }
  382. }
  383. }
  384. void TestSetRootConstants()
  385. {
  386. AZ::SimpleLcgRandom random(s_randomSeed);
  387. MultiDeviceDrawPacketData drawPacketData(random);
  388. RHI::DrawPacketBuilder builder(LocalDeviceMask);
  389. const auto drawPacket = drawPacketData.Build(builder);
  390. RHI::DrawPacketBuilder builder2(LocalDeviceMask);
  391. RHI::Ptr<RHI::DrawPacket> drawPacketClone = builder2.Clone(drawPacket.get());
  392. AZStd::vector<AZStd::array<uint8_t, sizeof(unsigned int) * 4>> rootConstantOld(LocalDeviceCount);
  393. for (auto deviceIndex{ 0 }; deviceIndex < LocalDeviceCount; ++deviceIndex)
  394. {
  395. auto deviceDrawPacketClone{ drawPacketClone->GetDeviceDrawPacket(deviceIndex) };
  396. EXPECT_EQ(sizeof(unsigned int) * 4, deviceDrawPacketClone->m_rootConstantSize);
  397. }
  398. // Keep a copy of old root constant for later verification
  399. for (auto deviceIndex{ 0 }; deviceIndex < LocalDeviceCount; ++deviceIndex)
  400. {
  401. auto deviceDrawPacketClone{ drawPacketClone->GetDeviceDrawPacket(deviceIndex) };
  402. for (uint8_t i = 0; i < deviceDrawPacketClone->m_rootConstantSize; ++i)
  403. {
  404. rootConstantOld[deviceIndex][i] = deviceDrawPacketClone->m_rootConstants[i];
  405. }
  406. }
  407. // Root constant data to be set, partial size as of the full root constant size.
  408. AZStd::array<uint8_t, sizeof(unsigned int)* 2> rootConstantNew = { 1, 2, 3, 4, 5, 6, 7, 8 };
  409. // Attempt to set beyond the array
  410. AZ_TEST_START_TRACE_SUPPRESSION;
  411. drawPacketClone->SetRootConstant(9, rootConstantNew);
  412. AZ_TEST_STOP_TRACE_SUPPRESSION(1);
  413. // Nothing will be set if it triggers the assert
  414. for (auto deviceIndex{ 0 }; deviceIndex < LocalDeviceCount; ++deviceIndex)
  415. {
  416. auto deviceDrawPacketClone{ drawPacketClone->GetDeviceDrawPacket(deviceIndex) };
  417. for (uint8_t i = 0; i < deviceDrawPacketClone->m_rootConstantSize; ++i)
  418. {
  419. EXPECT_EQ(rootConstantOld[deviceIndex][i], deviceDrawPacketClone->m_rootConstants[i]);
  420. }
  421. }
  422. drawPacketClone->SetRootConstant(8, rootConstantNew);
  423. for (auto deviceIndex{ 0 }; deviceIndex < LocalDeviceCount; ++deviceIndex)
  424. {
  425. auto deviceDrawPacket{ drawPacket->GetDeviceDrawPacket(deviceIndex) };
  426. auto deviceDrawPacketClone{ drawPacketClone->GetDeviceDrawPacket(deviceIndex) };
  427. // Compare the part staying the same.
  428. for (uint8_t i = 0; i < deviceDrawPacketClone->m_rootConstantSize - 8; ++i)
  429. {
  430. EXPECT_EQ(rootConstantOld[deviceIndex][i], deviceDrawPacketClone->m_rootConstants[i]);
  431. }
  432. // Compare the part being set
  433. for (uint8_t i = deviceDrawPacketClone->m_rootConstantSize - 8; i < deviceDrawPacketClone->m_rootConstantSize; ++i)
  434. {
  435. EXPECT_EQ(rootConstantNew[i - (deviceDrawPacketClone->m_rootConstantSize - 8)], deviceDrawPacketClone->m_rootConstants[i]);
  436. }
  437. // Compare the origin which shouldn't be affected.
  438. for (uint8_t i = 0; i < deviceDrawPacket->m_rootConstantSize; ++i)
  439. {
  440. EXPECT_EQ(rootConstantOld[deviceIndex][i], deviceDrawPacket->m_rootConstants[i]);
  441. }
  442. }
  443. }
  444. };
    // GoogleTest registrations. Each test case simply calls the fixture method of the same name;
    // the actual checks live in MultiDeviceDrawPacketTest.
    TEST_F(MultiDeviceDrawPacketTest, DrawPacketEmpty)
    {
        DrawPacketEmpty();
    }
    TEST_F(MultiDeviceDrawPacketTest, DrawPacketNullItem)
    {
        DrawPacketNullItem();
    }
    TEST_F(MultiDeviceDrawPacketTest, DrawPacketBuild)
    {
        DrawPacketBuild();
    }
    TEST_F(MultiDeviceDrawPacketTest, DrawPacketBuildClearBuildNull)
    {
        DrawPacketBuildClearBuildNull();
    }
    TEST_F(MultiDeviceDrawPacketTest, DrawListContextFilter)
    {
        DrawListContextFilter();
    }
    TEST_F(MultiDeviceDrawPacketTest, DrawListContextNullFilter)
    {
        DrawListContextNullFilter();
    }
    TEST_F(MultiDeviceDrawPacketTest, DrawPacketClone)
    {
        DrawPacketClone();
    }
    TEST_F(MultiDeviceDrawPacketTest, TestSetInstanceCount)
    {
        TestSetInstanceCount();
    }
    TEST_F(MultiDeviceDrawPacketTest, TestSetRootConstants)
    {
        TestSetRootConstants();
    }
  481. } // namespace UnitTest