// FrameSchedulerTests.cpp (17 KB)
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include "RHITestFixture.h"
  9. #include <Tests/Factory.h>
  10. #include <Tests/Device.h>
  11. #include <Atom/RHI/ScopeProducer.h>
  12. #include <Atom/RHI/FrameScheduler.h>
  13. #include <AzCore/Math/Random.h>
  14. #include <Atom/RHI/BufferPool.h>
  15. #include <Atom/RHI/ImagePool.h>
  16. #include <Atom/RHI/RHISystemInterface.h>
  17. namespace UnitTest
  18. {
  19. using namespace AZ;
  20. struct ImportedImage
  21. {
  22. RHI::AttachmentId m_id;
  23. RHI::Ptr<RHI::Image> m_image;
  24. };
  25. struct ImportedBuffer
  26. {
  27. RHI::AttachmentId m_id;
  28. RHI::Ptr<RHI::Buffer> m_buffer;
  29. };
  30. struct TransientImage
  31. {
  32. RHI::AttachmentId m_id;
  33. RHI::ImageDescriptor m_descriptor;
  34. };
  35. struct TransientBuffer
  36. {
  37. RHI::AttachmentId m_id;
  38. RHI::BufferDescriptor m_descriptor;
  39. };
  40. class ScopeProducer
  41. : public RHI::ScopeProducer
  42. {
  43. public:
  44. AZ_CLASS_ALLOCATOR(ScopeProducer, SystemAllocator);
  45. ScopeProducer(const RHI::ScopeId& scopeId)
  46. : RHI::ScopeProducer(scopeId)
  47. {}
  48. void SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph) override
  49. {
  50. RHI::FrameGraphAttachmentInterface attachmentDatabase = frameGraph.GetAttachmentDatabase();
  51. for (ImportedImage& image : m_imageImports)
  52. {
  53. ASSERT_FALSE(attachmentDatabase.IsAttachmentValid(image.m_id));
  54. attachmentDatabase.ImportImage(image.m_id, image.m_image);
  55. ASSERT_TRUE(attachmentDatabase.IsAttachmentValid(image.m_id));
  56. }
  57. for (ImportedBuffer& buffer : m_bufferImports)
  58. {
  59. ASSERT_FALSE(attachmentDatabase.IsAttachmentValid(buffer.m_id));
  60. attachmentDatabase.ImportBuffer(buffer.m_id, buffer.m_buffer);
  61. ASSERT_TRUE(attachmentDatabase.IsAttachmentValid(buffer.m_id));
  62. }
  63. for (const TransientImage& image : m_transientImages)
  64. {
  65. ASSERT_FALSE(attachmentDatabase.IsAttachmentValid(image.m_id));
  66. attachmentDatabase.CreateTransientImage(RHI::TransientImageDescriptor{image.m_id, image.m_descriptor});
  67. ASSERT_TRUE(attachmentDatabase.IsAttachmentValid(image.m_id));
  68. }
  69. for (const TransientBuffer& buffer : m_transientBuffers)
  70. {
  71. ASSERT_FALSE(attachmentDatabase.IsAttachmentValid(buffer.m_id));
  72. attachmentDatabase.CreateTransientBuffer(RHI::TransientBufferDescriptor{buffer.m_id, buffer.m_descriptor});
  73. ASSERT_TRUE(attachmentDatabase.IsAttachmentValid(buffer.m_id));
  74. }
  75. for (const ImageUsage& usage : m_imageUsages)
  76. {
  77. frameGraph.UseShaderAttachment(usage.m_descriptor, usage.m_access, RHI::ScopeAttachmentStage::AnyGraphics);
  78. }
  79. for (const BufferUsage& usage : m_bufferUsages)
  80. {
  81. frameGraph.UseShaderAttachment(usage.m_descriptor, usage.m_access, RHI::ScopeAttachmentStage::AnyGraphics);
  82. }
  83. frameGraph.SetEstimatedItemCount(0);
  84. }
  85. void CompileResources(const RHI::FrameGraphCompileContext& context) override
  86. {
  87. ASSERT_TRUE(context.GetScopeId() == GetScopeId());
  88. for (const ImageUsage& usage : m_imageUsages)
  89. {
  90. ASSERT_TRUE(context.GetImageView(usage.m_descriptor.m_attachmentId) != nullptr);
  91. }
  92. for (const BufferUsage& usage : m_bufferUsages)
  93. {
  94. ASSERT_TRUE(context.GetBufferView(usage.m_descriptor.m_attachmentId) != nullptr);
  95. }
  96. }
  97. void BuildCommandList(const RHI::FrameGraphExecuteContext& context) override
  98. {
  99. ASSERT_TRUE(context.GetScopeId() == GetScopeId());
  100. ASSERT_TRUE(context.GetCommandListIndex() == 0);
  101. ASSERT_TRUE(context.GetCommandListCount() == 1);
  102. }
  103. AZStd::vector<ImportedImage> m_imageImports;
  104. AZStd::vector<ImportedBuffer> m_bufferImports;
  105. AZStd::vector<TransientImage> m_transientImages;
  106. AZStd::vector<TransientBuffer> m_transientBuffers;
  107. struct ImageUsage
  108. {
  109. RHI::ImageScopeAttachmentDescriptor m_descriptor;
  110. RHI::ScopeAttachmentAccess m_access;
  111. };
  112. struct BufferUsage
  113. {
  114. RHI::BufferScopeAttachmentDescriptor m_descriptor;
  115. RHI::ScopeAttachmentAccess m_access;
  116. };
  117. AZStd::vector<ImageUsage> m_imageUsages;
  118. AZStd::vector<BufferUsage> m_bufferUsages;
  119. };
  120. class FrameSchedulerTests
  121. : public RHITestFixture
  122. {
  123. public:
  124. FrameSchedulerTests()
  125. : RHITestFixture()
  126. {
  127. }
  128. void SetUp() override
  129. {
  130. UnitTest::RHITestFixture::SetUp();
  131. m_rootFactory.reset(aznew Factory());
  132. m_rhiSystem.reset(aznew AZ::RHI::RHISystem);
  133. m_rhiSystem->InitDevices();
  134. m_rhiSystem->Init();
  135. m_device = AZ::RHI::RHISystemInterface::Get()->GetDevice(RHI::MultiDevice::DefaultDeviceIndex);
  136. m_state.reset(new State);
  137. {
  138. m_state->m_bufferPool = aznew RHI::BufferPool;
  139. RHI::BufferPoolDescriptor desc;
  140. desc.m_bindFlags = RHI::BufferBindFlags::ShaderReadWrite;
  141. m_state->m_bufferPool->Init(RHI::MultiDevice::DefaultDevice, desc);
  142. }
  143. for (uint32_t i = 0; i < ImportedBufferCount; ++i)
  144. {
  145. RHI::Ptr<RHI::Buffer> buffer;
  146. buffer = aznew RHI::Buffer;
  147. RHI::BufferDescriptor desc;
  148. desc.m_bindFlags = RHI::BufferBindFlags::ShaderReadWrite;
  149. desc.m_byteCount = BufferSize;
  150. RHI::BufferInitRequest request;
  151. request.m_descriptor = desc;
  152. request.m_buffer = buffer.get();
  153. m_state->m_bufferPool->InitBuffer(request);
  154. m_state->m_bufferAttachments[i].m_id = RHI::AttachmentId{AZStd::string::format("B%d", i)};
  155. m_state->m_bufferAttachments[i].m_buffer = AZStd::move(buffer);
  156. }
  157. {
  158. m_state->m_imagePool = aznew RHI::ImagePool();
  159. RHI::ImagePoolDescriptor desc;
  160. desc.m_bindFlags = RHI::ImageBindFlags::ShaderReadWrite;
  161. m_state->m_imagePool->Init(RHI::MultiDevice::AllDevices, desc);
  162. }
  163. for (uint32_t i = 0; i < ImportedImageCount; ++i)
  164. {
  165. RHI::Ptr<RHI::Image> image;
  166. image = aznew RHI::Image();
  167. RHI::ImageDescriptor desc = RHI::ImageDescriptor::Create2D(
  168. RHI::ImageBindFlags::ShaderReadWrite,
  169. ImageSize,
  170. ImageSize,
  171. RHI::Format::R8G8B8A8_UNORM);
  172. RHI::ImageInitRequest request;
  173. request.m_descriptor = desc;
  174. request.m_image = image.get();
  175. m_state->m_imagePool->InitImage(request);
  176. m_state->m_imageAttachments[i].m_id = RHI::AttachmentId{AZStd::string::format("I%d", i)};
  177. m_state->m_imageAttachments[i].m_image = AZStd::move(image);
  178. }
  179. for (uint32_t i = 0; i < ScopeCount; ++i)
  180. {
  181. m_state->m_producers.emplace_back(aznew ScopeProducer(RHI::ScopeId{AZStd::string::format("S%d", i)}));
  182. }
  183. }
  184. void TearDown() override
  185. {
  186. m_state.reset();
  187. m_device = nullptr;
  188. m_rhiSystem->Shutdown();
  189. m_rhiSystem.reset();
  190. m_rootFactory.reset();
  191. RHITestFixture::TearDown();
  192. }
  193. void Test()
  194. {
  195. RHI::FrameScheduler frameScheduler;
  196. RHI::FrameSchedulerDescriptor descriptor;
  197. descriptor.m_transientAttachmentPoolDescriptors[RHI::MultiDevice::DefaultDeviceIndex].m_bufferBudgetInBytes = 80 * 1024 * 1024;
  198. frameScheduler.Init(RHI::MultiDevice::DefaultDevice, descriptor);
  199. RHI::ImageScopeAttachmentDescriptor imageBindingDescs[2];
  200. imageBindingDescs[0].m_imageViewDescriptor = RHI::ImageViewDescriptor();
  201. imageBindingDescs[0].m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Clear;
  202. imageBindingDescs[0].m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(1.0f, 0.0, 0.0, 0.0);
  203. imageBindingDescs[1] = imageBindingDescs[0];
  204. imageBindingDescs[1].m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  205. RHI::BufferScopeAttachmentDescriptor bufferBindingDescs[2];
  206. bufferBindingDescs[0].m_bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, BufferSize);
  207. bufferBindingDescs[0].m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Clear;
  208. bufferBindingDescs[0].m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(1.0f, 0.0, 0.0, 0.0);
  209. bufferBindingDescs[1] = bufferBindingDescs[0];
  210. bufferBindingDescs[1].m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  211. AZ::SimpleLcgRandom random;
  212. struct Interval
  213. {
  214. uint32_t m_begin;
  215. uint32_t m_end;
  216. };
  217. Interval bufferScopeIntervals[BufferCount];
  218. for (uint32_t i = 0; i < BufferCount; ++i)
  219. {
  220. uint32_t b = random.GetRandom() % ScopeCount;
  221. uint32_t e = random.GetRandom() % ScopeCount;
  222. if (b > e)
  223. {
  224. AZStd::swap(b, e);
  225. }
  226. bufferScopeIntervals[i].m_begin = b;
  227. bufferScopeIntervals[i].m_end = e;
  228. }
  229. Interval imageScopeIntervals[ImageCount];
  230. for (uint32_t i = 0; i < ImageCount; ++i)
  231. {
  232. uint32_t b = random.GetRandom() % ScopeCount;
  233. uint32_t e = random.GetRandom() % ScopeCount;
  234. if (b > e)
  235. {
  236. AZStd::swap(b, e);
  237. }
  238. imageScopeIntervals[i].m_begin = b;
  239. imageScopeIntervals[i].m_end = e;
  240. }
  241. for (uint32_t scopeIdx = 0; scopeIdx < ScopeCount; ++scopeIdx)
  242. {
  243. ScopeProducer& producer = *m_state->m_producers[scopeIdx];
  244. //
  245. // IMPORTS
  246. //
  247. for (uint32_t i = 0; i < ImportedBufferCount; ++i)
  248. {
  249. if (scopeIdx == bufferScopeIntervals[i].m_begin)
  250. {
  251. producer.m_bufferImports.push_back(m_state->m_bufferAttachments[i]);
  252. bufferBindingDescs[0].m_attachmentId = m_state->m_bufferAttachments[i].m_id;
  253. producer.m_bufferUsages.push_back(ScopeProducer::BufferUsage{ bufferBindingDescs[0], RHI::ScopeAttachmentAccess::ReadWrite });
  254. }
  255. else if (scopeIdx == bufferScopeIntervals[i].m_end)
  256. {
  257. bufferBindingDescs[1].m_attachmentId = m_state->m_bufferAttachments[i].m_id;
  258. producer.m_bufferUsages.push_back(ScopeProducer::BufferUsage{ bufferBindingDescs[1], RHI::ScopeAttachmentAccess::Read });
  259. }
  260. }
  261. for (uint32_t i = 0; i < ImportedImageCount; ++i)
  262. {
  263. if (scopeIdx == imageScopeIntervals[i].m_begin)
  264. {
  265. producer.m_imageImports.push_back(m_state->m_imageAttachments[i]);
  266. imageBindingDescs[0].m_attachmentId = m_state->m_imageAttachments[i].m_id;
  267. producer.m_imageUsages.push_back(ScopeProducer::ImageUsage{ imageBindingDescs[0], RHI::ScopeAttachmentAccess::ReadWrite });
  268. }
  269. else if (scopeIdx == imageScopeIntervals[i].m_end)
  270. {
  271. imageBindingDescs[1].m_attachmentId = m_state->m_imageAttachments[i].m_id;
  272. producer.m_imageUsages.push_back(ScopeProducer::ImageUsage{ imageBindingDescs[1], RHI::ScopeAttachmentAccess::Read });
  273. }
  274. }
  275. //
  276. // TRANSIENTS
  277. //
  278. for (uint32_t i = 0; i < TransientBufferCount; ++i)
  279. {
  280. const uint32_t adjustedIndex = i + ImportedBufferCount;
  281. TransientBuffer transientBuffer =
  282. {
  283. RHI::AttachmentId{AZStd::string::format("B%d", adjustedIndex)},
  284. RHI::BufferDescriptor(RHI::BufferBindFlags::ShaderReadWrite, BufferSize)
  285. };
  286. bufferBindingDescs[0].m_attachmentId = transientBuffer.m_id;
  287. bufferBindingDescs[1].m_attachmentId = transientBuffer.m_id;
  288. if (scopeIdx == bufferScopeIntervals[adjustedIndex].m_begin)
  289. {
  290. producer.m_transientBuffers.push_back(transientBuffer);
  291. producer.m_bufferUsages.push_back(ScopeProducer::BufferUsage{ bufferBindingDescs[0], RHI::ScopeAttachmentAccess::ReadWrite });
  292. }
  293. else if (scopeIdx == bufferScopeIntervals[adjustedIndex].m_end)
  294. {
  295. producer.m_bufferUsages.push_back(ScopeProducer::BufferUsage{ bufferBindingDescs[1], RHI::ScopeAttachmentAccess::Read });
  296. }
  297. }
  298. for (uint32_t i = 0; i < TransientImageCount; ++i)
  299. {
  300. const uint32_t adjustedIndex = i + ImportedImageCount;
  301. TransientImage transientImage =
  302. {
  303. RHI::AttachmentId{AZStd::string::format("I%d", adjustedIndex)},
  304. RHI::ImageDescriptor::Create2D(RHI::ImageBindFlags::ShaderReadWrite, ImageSize, ImageSize, RHI::Format::R8G8B8A8_UNORM)
  305. };
  306. imageBindingDescs[0].m_attachmentId = transientImage.m_id;
  307. imageBindingDescs[1].m_attachmentId = transientImage.m_id;
  308. if (scopeIdx == imageScopeIntervals[adjustedIndex].m_begin)
  309. {
  310. producer.m_transientImages.push_back(transientImage);
  311. producer.m_imageUsages.push_back(ScopeProducer::ImageUsage{ imageBindingDescs[0], RHI::ScopeAttachmentAccess::ReadWrite });
  312. }
  313. else if (scopeIdx == imageScopeIntervals[adjustedIndex].m_end)
  314. {
  315. producer.m_imageUsages.push_back(ScopeProducer::ImageUsage{ imageBindingDescs[1], RHI::ScopeAttachmentAccess::Read });
  316. }
  317. }
  318. }
  319. for (uint32_t frameIdx = 0; frameIdx < FrameIterationCount; ++frameIdx)
  320. {
  321. frameScheduler.BeginFrame();
  322. for (AZStd::unique_ptr<ScopeProducer>& producer : m_state->m_producers)
  323. {
  324. frameScheduler.ImportScopeProducer(*producer);
  325. }
  326. RHI::FrameSchedulerCompileRequest compileRequest;
  327. compileRequest.m_jobPolicy = RHI::JobPolicy::Serial;
  328. frameScheduler.Compile(compileRequest);
  329. frameScheduler.Execute(RHI::JobPolicy::Serial);
  330. frameScheduler.EndFrame();
  331. }
  332. frameScheduler.Shutdown();
  333. }
  334. private:
  335. static const uint32_t FrameIterationCount = 128;
  336. static const uint32_t ImportedImageCount = 16;
  337. static const uint32_t ImportedBufferCount = 16;
  338. static const uint32_t TransientBufferCount = 16;
  339. static const uint32_t TransientImageCount = 16;
  340. static const uint32_t BufferCount = ImportedBufferCount + TransientBufferCount;
  341. static const uint32_t ImageCount = ImportedImageCount + TransientImageCount;
  342. static const uint32_t BufferSize = 64;
  343. static const uint32_t ImageSize = 16;
  344. static const uint32_t ScopeCount = 16;
  345. AZStd::unique_ptr<Factory> m_rootFactory;
  346. AZStd::unique_ptr<AZ::RHI::RHISystem> m_rhiSystem; //! Needed for the TransientAttachmentPool in the FrameScheduler
  347. RHI::Ptr<RHI::Device> m_device;
  348. struct State
  349. {
  350. RHI::Ptr<RHI::BufferPool> m_bufferPool;
  351. RHI::Ptr<RHI::ImagePool> m_imagePool;
  352. ImportedImage m_imageAttachments[ImportedImageCount];
  353. ImportedBuffer m_bufferAttachments[ImportedBufferCount];
  354. AZStd::vector<AZStd::unique_ptr<ScopeProducer>> m_producers;
  355. };
  356. AZStd::unique_ptr<State> m_state;
  357. };
  358. TEST_F(FrameSchedulerTests, Test)
  359. {
  360. Test();
  361. }
  362. }