GrWorkGraphs.cpp

// Copyright (C) 2009-present, Panagiotis Christopoulos Charitos and contributors.
// All rights reserved.
// Code licensed under the BSD License.
// http://www.anki3d.org/LICENSE

#include <Tests/Framework/Framework.h>
#include <Tests/Gr/GrCommon.h>
#include <AnKi/Util/HighRezTimer.h>
#include <AnKi/Gr.h>

using namespace anki;
static void clearSwapchain(CommandBufferPtr cmdb = CommandBufferPtr())
{
    const Bool continueCmdb = cmdb.isCreated();

    TexturePtr presentTex = GrManager::getSingleton().acquireNextPresentableTexture();

    if(!continueCmdb)
    {
        CommandBufferInitInfo cinit;
        cinit.m_flags = CommandBufferFlag::kGeneralWork | CommandBufferFlag::kSmallBatch;
        cmdb = GrManager::getSingleton().newCommandBuffer(cinit);
    }

    const TextureBarrierInfo barrier = {TextureView(presentTex.get(), TextureSubresourceDesc::all()), TextureUsageBit::kNone,
                                        TextureUsageBit::kRtvDsvWrite};
    cmdb->setPipelineBarrier({&barrier, 1}, {}, {});

    RenderTarget rt;
    rt.m_textureView = TextureView(presentTex.get(), TextureSubresourceDesc::all());
    rt.m_clearValue.m_colorf = {1.0f, F32(rand()) / F32(RAND_MAX), 1.0f, 1.0f};
    cmdb->beginRenderPass({rt});
    cmdb->endRenderPass();

    const TextureBarrierInfo barrier2 = {TextureView(presentTex.get(), TextureSubresourceDesc::all()), TextureUsageBit::kRtvDsvWrite,
                                         TextureUsageBit::kPresent};
    cmdb->setPipelineBarrier({&barrier2, 1}, {}, {});

    if(!continueCmdb)
    {
        cmdb->endRecording();
        GrManager::getSingleton().submit(cmdb.get());
    }
}
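
// Shared benchmark driver: the workload is split across command buffers, each one bracketed by a pair of GPU
// timestamp queries, while command buffer recording is timed on the CPU. Note that the CPU figure covers recording
// only (the timer stops right after endRecording(), before submit()), and the GPU figure is the sum of the
// per-command-buffer timestamp deltas divided by the total iteration count.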
template<typename TFunc>
static void runBenchmark(U32 iterationCount, U32 iterationsPerCommandBuffer, Bool bBenchmark, TFunc func)
{
    ANKI_ASSERT(iterationCount >= iterationsPerCommandBuffer && (iterationCount % iterationsPerCommandBuffer) == 0);

    FencePtr fence;
    F64 avgCpuTimePerIterationMs = 0.0;
    DynamicArray<TimestampQueryPtr> timestamps;
    const U32 commandBufferCount = iterationCount / iterationsPerCommandBuffer;

    for(U32 icmdb = 0; icmdb < commandBufferCount; ++icmdb)
    {
        CommandBufferPtr cmdb = GrManager::getSingleton().newCommandBuffer(CommandBufferInitInfo(CommandBufferFlag::kGeneralWork));

        TimestampQueryPtr query1 = GrManager::getSingleton().newTimestampQuery();
        cmdb->writeTimestamp(query1.get());
        timestamps.emplaceBack(query1);

        const U64 cpuTimeStart = HighRezTimer::getCurrentTimeUs();

        for(U32 i = 0; i < iterationsPerCommandBuffer; ++i)
        {
            func(*cmdb);
        }

        // clearSwapchain(cmdb);

        TimestampQueryPtr query2 = GrManager::getSingleton().newTimestampQuery();
        cmdb->writeTimestamp(query2.get());
        timestamps.emplaceBack(query2);

        cmdb->endRecording();

        const U64 cpuTimeEnd = HighRezTimer::getCurrentTimeUs();
        avgCpuTimePerIterationMs += (Second(cpuTimeEnd - cpuTimeStart) * 0.001) / Second(iterationCount);

        GrManager::getSingleton().submit(cmdb.get(), {}, (icmdb == commandBufferCount - 1) ? &fence : nullptr);
        // GrManager::getSingleton().swapBuffers();
    }

    const Bool done = fence->clientWait(kMaxSecond);
    ANKI_TEST_EXPECT_EQ(done, true);

    F64 avgTimePerIterationMs = 0.0;
    for(U32 i = 0; i < timestamps.getSize(); i += 2)
    {
        Second a, b;
        ANKI_TEST_EXPECT_EQ(timestamps[i]->getResult(a), TimestampQueryResult::kAvailable);
        ANKI_TEST_EXPECT_EQ(timestamps[i + 1]->getResult(b), TimestampQueryResult::kAvailable);
        avgTimePerIterationMs += (Second(b - a) * 1000.0) / Second(iterationCount);
    }

    if(bBenchmark)
    {
        ANKI_TEST_LOGI("Benchmark: avg GPU time: %fms, avg CPU time: %fms", avgTimePerIterationMs, avgCpuTimePerIterationMs);
    }
}
void commonInitWg(Bool& bBenchmark, Bool& bWorkgraphs)
{
    bBenchmark = getenv("BENCHMARK") && CString(getenv("BENCHMARK")) == "1";

    [[maybe_unused]] Error err = CVarSet::getSingleton().setMultiple(Array<const Char*, 2>{"WorkGraphs", "1"});
    commonInit(!bBenchmark);

    bWorkgraphs = getenv("WORKGRAPHS") && CString(getenv("WORKGRAPHS")) == "1" && GrManager::getSingleton().getDeviceCapabilities().m_workGraphs;

    ANKI_TEST_LOGI("Testing with BENCHMARK=%u WORKGRAPHS=%u", bBenchmark, bWorkgraphs);
}
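
// The tests below are driven by two environment variables: BENCHMARK=1 switches from a single validation run to
// repeated timed runs, and WORKGRAPHS=1 selects the work-graph code path (only honored when the device reports
// work-graph support); otherwise an equivalent plain-compute fallback runs, so the two paths can be compared.
// commonInitWg() also turns the WorkGraphs CVar on before initializing GR.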
ANKI_TEST(Gr, WorkGraphHelloWorld)
{
    // CVarSet::getSingleton().setMultiple(Array<const Char*, 2>{"Device", "1"});
    commonInit();

    {
        const Char* kSrc = R"(
struct FirstNodeRecord
{
    uint3 m_gridSize : SV_DispatchGrid;
    uint m_value;
};

struct SecondNodeRecord
{
    uint3 m_gridSize : SV_DispatchGrid;
    uint m_value;
};

struct ThirdNodeRecord
{
    uint m_value;
};

RWStructuredBuffer<uint> g_buff : register(u0);

[Shader("node")] [NodeLaunch("broadcasting")] [NodeIsProgramEntry] [NodeMaxDispatchGrid(1, 1, 1)] [NumThreads(16, 1, 1)]
void main(DispatchNodeInputRecord<FirstNodeRecord> inp, [MaxRecords(2)] NodeOutput<SecondNodeRecord> secondNode, uint svGroupIndex : SV_GroupIndex)
{
    GroupNodeOutputRecords<SecondNodeRecord> rec = secondNode.GetGroupNodeOutputRecords(2);

    if(svGroupIndex < 2)
    {
        rec[svGroupIndex].m_gridSize = uint3(16, 1, 1);
        rec[svGroupIndex].m_value = inp.Get().m_value;
    }

    rec.OutputComplete();
}

[Shader("node")] [NodeLaunch("broadcasting")] [NumThreads(16, 1, 1)] [NodeMaxDispatchGrid(16, 1, 1)]
void secondNode(DispatchNodeInputRecord<SecondNodeRecord> inp, [MaxRecords(32)] NodeOutput<ThirdNodeRecord> thirdNode,
                uint svGroupIndex : SV_GroupIndex)
{
    GroupNodeOutputRecords<ThirdNodeRecord> recs = thirdNode.GetGroupNodeOutputRecords(32);

    recs[svGroupIndex * 2 + 0].m_value = inp.Get().m_value;
    recs[svGroupIndex * 2 + 1].m_value = inp.Get().m_value;

    recs.OutputComplete();
}

[Shader("node")] [NodeLaunch("coalescing")] [NumThreads(16, 1, 1)]
void thirdNode([MaxRecords(32)] GroupNodeInputRecords<ThirdNodeRecord> inp, uint svGroupIndex : SV_GroupIndex)
{
    if(svGroupIndex * 2 < inp.Count())
        InterlockedAdd(g_buff[0], inp[svGroupIndex * 2].m_value);
    if(svGroupIndex * 2 + 1 < inp.Count())
        InterlockedAdd(g_buff[0], inp[svGroupIndex * 2 + 1].m_value);
}
)";

        ShaderPtr shader = createShader(kSrc, ShaderType::kWorkGraph);

        ShaderProgramInitInfo progInit;
        progInit.m_workGraph.m_shader = shader.get();
        WorkGraphNodeSpecialization wgSpecialization = {"main", UVec3(4, 1, 1)};
        progInit.m_workGraph.m_nodeSpecializations = ConstWeakArray<WorkGraphNodeSpecialization>(&wgSpecialization, 1);
        ShaderProgramPtr prog = GrManager::getSingleton().newShaderProgram(progInit);

        BufferPtr counterBuff = createBuffer(BufferUsageBit::kAllUav | BufferUsageBit::kCopySource, 0u, 1, "CounterBuffer");
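
        // Work graphs need GPU-side backing memory for records in flight. The required size is queried from the
        // compiled program via getWorkGraphMemoryRequirements() and allocated as a plain UAV buffer below.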
        BufferInitInfo scratchInit("scratch");
        scratchInit.m_size = prog->getWorkGraphMemoryRequirements();
        scratchInit.m_usage = BufferUsageBit::kAllUav;
        BufferPtr scratchBuff = GrManager::getSingleton().newBuffer(scratchInit);

        struct FirstNodeRecord
        {
            UVec3 m_gridSize;
            U32 m_value;
        };

        Array<FirstNodeRecord, 2> records;
        for(U32 i = 0; i < records.getSize(); ++i)
        {
            records[i].m_gridSize = UVec3(4, 1, 1);
            records[i].m_value = (i + 1) * 10;
        }

        CommandBufferPtr cmdb = GrManager::getSingleton().newCommandBuffer(CommandBufferInitInfo(CommandBufferFlag::kSmallBatch));

        cmdb->bindShaderProgram(prog.get());
        cmdb->bindUav(0, 0, BufferView(counterBuff.get()));
        cmdb->dispatchGraph(BufferView(scratchBuff.get()), records.getBegin(), records.getSize(), sizeof(records[0]));

        cmdb->endRecording();

        FencePtr fence;
        GrManager::getSingleton().submit(cmdb.get(), {}, &fence);
        fence->clientWait(kMaxSecond);
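
        // Expected value, worked through the graph: each of the 2 input records launches main with a (4, 1, 1) grid.
        // Every main group emits 2 SecondNodeRecords with a (16, 1, 1) grid, so 4 * 2 = 8 per input record. Every
        // secondNode group of those 16 writes 32 ThirdNodeRecords, i.e. 16 * 32 = 512 per SecondNodeRecord. thirdNode
        // then adds the record value once per ThirdNodeRecord: (10 + 20) * 8 * 512 = 122880.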
        validateBuffer(counterBuff, ConstWeakArray(Array<U32, 1>{122880}));
    }

    commonDestroy();
}
ANKI_TEST(Gr, WorkGraphAmplification)
{
    // CVarSet::getSingleton().setMultiple(Array<const Char*, 2>{"Device", "2"});
    Bool bBenchmark, bWorkgraphs;
    commonInitWg(bBenchmark, bWorkgraphs);

    {
        const Char* kSrc = R"(
struct FirstNodeRecord
{
    uint3 m_dispatchGrid : SV_DispatchGrid;
};

struct SecondNodeRecord
{
    uint3 m_dispatchGrid : SV_DispatchGrid;
    uint m_objectIndex;
};

struct Aabb
{
    uint m_min;
    uint m_max;
};

struct Object
{
    uint m_positionsStart; // Points to g_positions
    uint m_positionCount;
};

RWStructuredBuffer<Aabb> g_aabbs : register(u0);
StructuredBuffer<Object> g_objects : register(t0);
StructuredBuffer<uint> g_positions : register(t1);

#define THREAD_COUNT 64u

// Operates per object
[Shader("node")] [NodeLaunch("broadcasting")] [NodeIsProgramEntry] [NodeMaxDispatchGrid(1, 1, 1)] [NumThreads(THREAD_COUNT, 1, 1)]
void main(DispatchNodeInputRecord<FirstNodeRecord> inp, [MaxRecords(THREAD_COUNT)] NodeOutput<SecondNodeRecord> computeAabb,
          uint svGroupIndex : SV_GroupIndex, uint svDispatchThreadId : SV_DispatchThreadId)
{
    GroupNodeOutputRecords<SecondNodeRecord> recs = computeAabb.GetGroupNodeOutputRecords(THREAD_COUNT);

    const Object obj = g_objects[svDispatchThreadId];
    recs[svGroupIndex].m_objectIndex = svDispatchThreadId;
    recs[svGroupIndex].m_dispatchGrid = uint3((obj.m_positionCount + (THREAD_COUNT - 1)) / THREAD_COUNT, 1, 1);

    recs.OutputComplete();
}

groupshared Aabb g_aabb;

// Operates per position
[Shader("node")] [NodeLaunch("broadcasting")] [NodeMaxDispatchGrid(1, 1, 1)] [NumThreads(THREAD_COUNT, 1, 1)]
void computeAabb(DispatchNodeInputRecord<SecondNodeRecord> inp, uint svDispatchThreadId : SV_DispatchThreadId, uint svGroupIndex : SV_GroupIndex)
{
    const Object obj = g_objects[inp.Get().m_objectIndex];
    svDispatchThreadId = min(svDispatchThreadId, obj.m_positionCount - 1);

    if(svGroupIndex == 0)
    {
        g_aabb.m_min = 0xFFFFFFFF;
        g_aabb.m_max = 0;
    }

    Barrier(GROUP_SHARED_MEMORY, GROUP_SCOPE | GROUP_SYNC);

    const uint positionIndex = obj.m_positionsStart + svDispatchThreadId;
    const uint pos = g_positions[positionIndex];
    InterlockedMin(g_aabb.m_min, pos);
    InterlockedMax(g_aabb.m_max, pos);

    Barrier(GROUP_SHARED_MEMORY, GROUP_SCOPE | GROUP_SYNC);

    InterlockedMin(g_aabbs[inp.Get().m_objectIndex].m_min, g_aabb.m_min);
    InterlockedMax(g_aabbs[inp.Get().m_objectIndex].m_max, g_aabb.m_max);
}
)";
        const Char* kComputeSrc = R"(
struct Aabb
{
    uint m_min;
    uint m_max;
};

struct Object
{
    uint m_positionsStart; // Points to g_positions
    uint m_positionCount;
};

struct PushConsts
{
    uint m_objectIndex;
    uint m_padding1;
    uint m_padding2;
    uint m_padding3;
};

RWStructuredBuffer<Aabb> g_aabbs : register(u0);
StructuredBuffer<Object> g_objects : register(t0);
StructuredBuffer<uint> g_positions : register(t1);

#if defined(__spirv__)
[[vk::push_constant]] ConstantBuffer<PushConsts> g_consts;
#else
ConstantBuffer<PushConsts> g_consts : register(b0, space3000);
#endif

#define THREAD_COUNT 64u

groupshared Aabb g_aabb;

[NumThreads(THREAD_COUNT, 1, 1)]
void main(uint svDispatchThreadId : SV_DispatchThreadId, uint svGroupIndex : SV_GroupIndex)
{
    const Object obj = g_objects[g_consts.m_objectIndex];
    svDispatchThreadId = min(svDispatchThreadId, obj.m_positionCount - 1);

    if(svGroupIndex == 0)
    {
        g_aabb.m_min = 0xFFFFFFFF;
        g_aabb.m_max = 0;
    }

    Barrier(GROUP_SHARED_MEMORY, GROUP_SCOPE | GROUP_SYNC);

    const uint positionIndex = obj.m_positionsStart + svDispatchThreadId;
    const uint pos = g_positions[positionIndex];
    InterlockedMin(g_aabb.m_min, pos);
    InterlockedMax(g_aabb.m_max, pos);

    Barrier(GROUP_SHARED_MEMORY, GROUP_SCOPE | GROUP_SYNC);

    InterlockedMin(g_aabbs[g_consts.m_objectIndex].m_min, g_aabb.m_min);
    InterlockedMax(g_aabbs[g_consts.m_objectIndex].m_max, g_aabb.m_max);
}
)";
        constexpr U32 kObjectCount = 1000 * 64;
        constexpr U32 kPositionsPerObject = 10 * 64; // 1 * 1024;
        constexpr U32 kThreadCount = 64;

        ShaderProgramPtr prog;
        if(bWorkgraphs)
        {
            ShaderPtr shader = createShader(kSrc, ShaderType::kWorkGraph);

            ShaderProgramInitInfo progInit;
            Array<WorkGraphNodeSpecialization, 2> specializations = {
                {{"main", UVec3((kObjectCount + kThreadCount - 1) / kThreadCount, 1, 1)},
                 {"computeAabb", UVec3((kPositionsPerObject + (kThreadCount - 1)) / kThreadCount, 1, 1)}}};
            progInit.m_workGraph.m_nodeSpecializations = specializations;
            progInit.m_workGraph.m_shader = shader.get();
            prog = GrManager::getSingleton().newShaderProgram(progInit);
        }
        else
        {
            ShaderPtr shader = createShader(kComputeSrc, ShaderType::kCompute);

            ShaderProgramInitInfo progInit;
            progInit.m_computeShader = shader.get();
            prog = GrManager::getSingleton().newShaderProgram(progInit);
        }

        struct Aabb
        {
            U32 m_min = kMaxU32;
            U32 m_max = 0;

            Bool operator==(const Aabb&) const = default;
        };

        struct Object
        {
            U32 m_positionsStart; // Points to g_positions
            U32 m_positionCount;
        };

        // Objects
        DynamicArray<Object> objects;
        objects.resize(kObjectCount);
        U32 positionCount = 0;
        for(Object& obj : objects)
        {
            obj.m_positionsStart = positionCount;
            obj.m_positionCount = kPositionsPerObject;
            positionCount += obj.m_positionCount;
        }

        printf("Obj count %u, pos count %u\n", kObjectCount, positionCount);

        BufferPtr objBuff = createBuffer(BufferUsageBit::kSrvCompute, ConstWeakArray(objects), "Objects");

        // AABBs
        BufferPtr aabbsBuff = createBuffer(BufferUsageBit::kUavCompute, Aabb(), kObjectCount, "AABBs");

        // Positions
        GrDynamicArray<U32> positions;
        positions.resize(positionCount);
        positionCount = 0;
        for(U32 iobj = 0; iobj < kObjectCount; ++iobj)
        {
            const Object& obj = objects[iobj];
            const U32 min = getRandomRange<U32>(0, kMaxU32 / 2 - 1);
            const U32 max = getRandomRange<U32>(kMaxU32 / 2, kMaxU32);
            for(U32 ipos = obj.m_positionsStart; ipos < obj.m_positionsStart + obj.m_positionCount; ++ipos)
            {
                positions[ipos] = getRandomRange<U32>(min, max);
                positions[ipos] = iobj; // NOTE: This overwrites the random value above, so every position of an object
                                        // ends up equal to iobj. The CPU-side check below reads the same array, so the
                                        // test stays self-consistent either way.
            }
            positionCount += obj.m_positionCount;
        }
        BufferPtr posBuff = createBuffer(BufferUsageBit::kSrvCompute, ConstWeakArray(positions), "Positions");

        BufferPtr scratchBuff;
        if(bWorkgraphs)
        {
            BufferInitInfo scratchInit("scratch");
            scratchInit.m_size = prog->getWorkGraphMemoryRequirements();
            scratchInit.m_usage = BufferUsageBit::kAllUav;
            scratchBuff = GrManager::getSingleton().newBuffer(scratchInit);
        }

        // Execute
        const U32 iterationsPerCmdb = (!bBenchmark) ? 1 : 100u;
        const U32 iterationCount = (!bBenchmark) ? iterationsPerCmdb : iterationsPerCmdb * 1;
        runBenchmark(iterationCount, iterationsPerCmdb, bBenchmark, [&](CommandBuffer& cmdb) {
            const BufferBarrierInfo barr = {BufferView(aabbsBuff.get()), BufferUsageBit::kUavCompute, BufferUsageBit::kUavCompute};
            cmdb.setPipelineBarrier({}, {&barr, 1}, {});

            if(bWorkgraphs)
            {
                struct FirstNodeRecord
                {
                    UVec3 m_gridSize;
                };

                Array<FirstNodeRecord, 1> records;
                records[0].m_gridSize = UVec3((objects.getSize() + kThreadCount - 1) / kThreadCount, 1, 1);

                cmdb.bindShaderProgram(prog.get());
                cmdb.bindUav(0, 0, BufferView(aabbsBuff.get()));
                cmdb.bindSrv(0, 0, BufferView(objBuff.get()));
                cmdb.bindSrv(1, 0, BufferView(posBuff.get()));
                cmdb.dispatchGraph(BufferView(scratchBuff.get()), records.getBegin(), records.getSize(), sizeof(records[0]));
            }
            else
            {
                cmdb.bindShaderProgram(prog.get());
                cmdb.bindUav(0, 0, BufferView(aabbsBuff.get()));
                cmdb.bindSrv(0, 0, BufferView(objBuff.get()));
                cmdb.bindSrv(1, 0, BufferView(posBuff.get()));

                for(U32 iobj = 0; iobj < kObjectCount; ++iobj)
                {
                    const UVec4 pc(iobj);
                    cmdb.setFastConstants(&pc, sizeof(pc));
                    cmdb.dispatchCompute((objects[iobj].m_positionCount + kThreadCount - 1) / kThreadCount, 1, 1);
                }
            }
        });

        // Check
        DynamicArray<Aabb> aabbs;
        readBuffer(aabbsBuff, aabbs);

        for(U32 i = 0; i < kObjectCount; ++i)
        {
            const Object& obj = objects[i];
            Aabb aabb;
            for(U32 ipos = obj.m_positionsStart; ipos < obj.m_positionsStart + obj.m_positionCount; ++ipos)
            {
                aabb.m_min = min(aabb.m_min, positions[ipos]);
                aabb.m_max = max(aabb.m_max, positions[ipos]);
            }

            if(aabb != aabbs[i])
            {
                printf("%u: %u %u | %u %u\n", i, aabb.m_min, aabbs[i].m_min, aabb.m_max, aabbs[i].m_max);
            }

            ANKI_TEST_EXPECT_EQ(aabb, aabbs[i]);
        }
    }

    commonDestroy();
}
ANKI_TEST(Gr, WorkGraphsWorkDrain)
{
    Bool bBenchmark, bWorkgraphs;
    commonInitWg(bBenchmark, bWorkgraphs);

#define TEX_SIZE_X 4096u
#define TEX_SIZE_Y 4096u
#define TILE_SIZE_X 32u
#define TILE_SIZE_Y 32u
#define TILE_COUNT_X (TEX_SIZE_X / TILE_SIZE_X)
#define TILE_COUNT_Y (TEX_SIZE_Y / TILE_SIZE_Y)
#define TILE_COUNT (TILE_COUNT_X * TILE_COUNT_Y)
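
    // Tile math for reference: 4096 / 32 = 128 tiles per axis, so TILE_COUNT is 128 * 128 = 16384. Judging by the
    // buffer names below, the first pass reduces each 32x32 tile into one entry of tileMax and the second pass
    // reduces the 16384 tile results into the single finalMax value.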
    {
        // Create WG prog
        ShaderProgramPtr wgProg;
        if(bWorkgraphs)
        {
            ShaderPtr wgShader = loadShader(ANKI_SOURCE_DIRECTORY "/Tests/Gr/WorkDrainWg.hlsl", ShaderType::kWorkGraph);

            ShaderProgramInitInfo progInit;
            progInit.m_workGraph.m_shader = wgShader.get();
            Array<WorkGraphNodeSpecialization, 1> specializations = {{{"main", UVec3(TILE_COUNT_X, TILE_COUNT_Y, 1)}}};
            progInit.m_workGraph.m_nodeSpecializations = specializations;
            wgProg = GrManager::getSingleton().newShaderProgram(progInit);
        }

        // Scratch buff
        BufferPtr scratchBuff;
        if(bWorkgraphs)
        {
            BufferInitInfo scratchInit("scratch");
            scratchInit.m_size = wgProg->getWorkGraphMemoryRequirements();
            scratchInit.m_usage = BufferUsageBit::kAllUav | BufferUsageBit::kAllSrv;
            scratchBuff = GrManager::getSingleton().newBuffer(scratchInit);
        }

        // Create compute progs
        ShaderProgramPtr compProg0, compProg1;
        {
            ShaderPtr shader =
                loadShader(ANKI_SOURCE_DIRECTORY "/Tests/Gr/WorkDrainCompute.hlsl", ShaderType::kCompute, Array<CString, 1>{"-DFIRST"});

            ShaderProgramInitInfo progInit;
            progInit.m_computeShader = shader.get();
            compProg0 = GrManager::getSingleton().newShaderProgram(progInit);

            shader = loadShader(ANKI_SOURCE_DIRECTORY "/Tests/Gr/WorkDrainCompute.hlsl", ShaderType::kCompute);
            progInit.m_computeShader = shader.get();
            compProg1 = GrManager::getSingleton().newShaderProgram(progInit);
        }

        // Create texture 2D
        TexturePtr tex;
        {
            DynamicArray<Vec4> data;
            data.resize(TEX_SIZE_X * TEX_SIZE_Y, Vec4(1.0f));
            data[10] = Vec4(1.1f, 2.06f, 3.88f, 0.5f);

            TextureInitInfo texInit("Tex");
            texInit.m_width = TEX_SIZE_X;
            texInit.m_height = TEX_SIZE_Y;
            texInit.m_format = Format::kR32G32B32A32_Sfloat;
            texInit.m_usage = TextureUsageBit::kAllUav | TextureUsageBit::kAllSrv;

            tex = createTexture2d(texInit, ConstWeakArray(data));
        }

        // Create counter buff
        BufferPtr threadgroupCountBuff = createBuffer(BufferUsageBit::kUavCompute, U32(0u), 1);

        // Result buffers
        BufferPtr tileMax = createBuffer(BufferUsageBit::kAllUav | BufferUsageBit::kAllSrv, Vec4(0.1f), TILE_COUNT);
        BufferPtr finalMax = createBuffer(BufferUsageBit::kAllUav | BufferUsageBit::kAllSrv, Vec4(0.1f), 1);

        const U32 iterationsPerCmdb = (!bBenchmark) ? 1 : 100u;
        const U32 iterationCount = (!bBenchmark) ? 1 : iterationsPerCmdb * 100;
        runBenchmark(iterationCount, iterationsPerCmdb, bBenchmark, [&](CommandBuffer& cmdb) {
            BufferBarrierInfo barr = {BufferView(tileMax.get()), BufferUsageBit::kUavCompute, BufferUsageBit::kUavCompute};
            cmdb.setPipelineBarrier({}, {&barr, 1}, {});

            cmdb.bindSrv(0, 0, TextureView(tex.get(), TextureSubresourceDesc::all()));
            cmdb.bindUav(0, 0, BufferView(tileMax.get()));
            cmdb.bindUav(1, 0, BufferView(finalMax.get()));
            cmdb.bindUav(2, 0, BufferView(threadgroupCountBuff.get()));

            if(bWorkgraphs)
            {
                cmdb.bindShaderProgram(wgProg.get());

                struct FirstNodeRecord
                {
                    UVec3 m_gridSize;
                };

                Array<FirstNodeRecord, 1> records;
                records[0].m_gridSize = UVec3(TILE_COUNT_X, TILE_COUNT_Y, 1);

                cmdb.dispatchGraph(BufferView(scratchBuff.get()), records.getBegin(), records.getSize(), sizeof(records[0]));
            }
            else
            {
                cmdb.bindShaderProgram(compProg0.get());
                cmdb.dispatchCompute(TILE_COUNT_X, TILE_COUNT_Y, 1);

                barr = {BufferView(tileMax.get()), BufferUsageBit::kUavCompute, BufferUsageBit::kUavCompute};
                cmdb.setPipelineBarrier({}, {&barr, 1}, {});

                cmdb.bindShaderProgram(compProg1.get());
                cmdb.dispatchCompute(1, 1, 1);
            }
        });
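
        // Why this expected value: the texture is all Vec4(1.0f) except texel 10, which holds (1.1, 2.06, 3.88, 0.5).
        // A component-wise max over the whole image therefore yields 1.1, 2.06 and 3.88 for RGB, while alpha stays at
        // the background's 1.0 (since 0.5 < 1.0).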
        validateBuffer2(finalMax, Vec4(1.1f, 2.06f, 3.88f, 1.0f));
    }

    commonDestroy();
}
ANKI_TEST(Gr, WorkGraphsOverhead)
{
    Bool bBenchmark, bWorkgraphs;
    commonInitWg(bBenchmark, bWorkgraphs);

    constexpr U32 kMaxCandidates = 100 * 1024;

    {
        // Create compute progs
        ShaderProgramPtr compProg;
        {
            ShaderPtr shader =
                loadShader(ANKI_SOURCE_DIRECTORY "/Tests/Gr/FindPrimeNumbers.hlsl", ShaderType::kCompute, Array<CString, 1>{"-DWORKGRAPHS=0"});

            ShaderProgramInitInfo progInit;
            progInit.m_computeShader = shader.get();
            compProg = GrManager::getSingleton().newShaderProgram(progInit);
        }

        // Create WG prog
        ShaderProgramPtr wgProg;
        if(bWorkgraphs)
        {
            ShaderPtr shader =
                loadShader(ANKI_SOURCE_DIRECTORY "/Tests/Gr/FindPrimeNumbers.hlsl", ShaderType::kWorkGraph, Array<CString, 1>{"-DWORKGRAPHS=1"});

            ShaderProgramInitInfo progInit;
            progInit.m_workGraph.m_shader = shader.get();
            Array<WorkGraphNodeSpecialization, 1> specializations = {{{"main", UVec3(kMaxCandidates, 1, 1)}}};
            progInit.m_workGraph.m_nodeSpecializations = specializations;
            wgProg = GrManager::getSingleton().newShaderProgram(progInit);
        }

        // Scratch buff
        BufferPtr scratchBuff;
        if(bWorkgraphs)
        {
            BufferInitInfo scratchInit("scratch");
            scratchInit.m_size = wgProg->getWorkGraphMemoryRequirements();
            scratchInit.m_usage = BufferUsageBit::kAllUav | BufferUsageBit::kAllSrv;
            scratchBuff = GrManager::getSingleton().newBuffer(scratchInit);
        }

        BufferPtr miscBuff = createBuffer(BufferUsageBit::kUavCompute, UVec3(0, kMaxCandidates, 0), 2);
        BufferPtr primeNumbersBuff = createBuffer(BufferUsageBit::kUavCompute, U32(0u), kMaxCandidates + 1);

        const U32 iterationsPerCmdb = (!bBenchmark) ? 1 : 100u;
        const U32 iterationCount = (!bBenchmark) ? 1 : iterationsPerCmdb * 50;
        runBenchmark(iterationCount, iterationsPerCmdb, bBenchmark, [&](CommandBuffer& cmdb) {
            BufferBarrierInfo barr = {BufferView(primeNumbersBuff.get()), BufferUsageBit::kUavCompute, BufferUsageBit::kUavCompute};
            cmdb.setPipelineBarrier({}, {&barr, 1}, {});

            cmdb.bindUav(0, 0, BufferView(primeNumbersBuff.get()));
            cmdb.bindUav(1, 0, BufferView(miscBuff.get()));

            if(bWorkgraphs)
            {
                cmdb.bindShaderProgram(wgProg.get());

                struct FirstNodeRecord
                {
                    UVec3 m_gridSize;
                };

                Array<FirstNodeRecord, 1> records;
                records[0].m_gridSize = UVec3(kMaxCandidates, 1, 1);

                cmdb.dispatchGraph(BufferView(scratchBuff.get()), records.getBegin(), records.getSize(), sizeof(records[0]));
            }
            else
            {
                cmdb.bindShaderProgram(compProg.get());
                cmdb.dispatchCompute(kMaxCandidates, 1, 1);
            }
        });

        DynamicArray<U32> values;
        readBuffer(primeNumbersBuff, values);
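
        // Buffer layout (mirrored by the CPU reference below): element 0 holds the number of primes found, followed
        // by the primes themselves in arbitrary order, hence the sort before the element-wise comparison.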
        values.resize(values[0] + 1);
        std::sort(values.getBegin() + 1, values.getEnd(), [](U32 a, U32 b) {
            return a < b;
        });

        auto isPrime = [](int N) {
            if(N <= 1)
            {
                return false;
            }

            // Trial division. Note the loop condition is `i < N / 2` (not <=), presumably matching the check in
            // FindPrimeNumbers.hlsl so that the CPU and GPU results agree.
            for(int i = 2; i < N / 2; i++)
            {
                if(N % i == 0)
                {
                    return false;
                }
            }

            return true;
        };

        DynamicArray<U32> values2;
        values2.resize(1, 0);
        for(U32 i = 0; i < kMaxCandidates; ++i)
        {
            if(isPrime(i))
            {
                ++values2[0];
                values2.emplaceBack(i);
            }
        }

        ANKI_TEST_EXPECT_EQ(values.getSize(), values2.getSize());
        for(U32 i = 0; i < values2.getSize(); ++i)
        {
            ANKI_TEST_EXPECT_EQ(values[i], values2[i]);
        }
    }

    commonDestroy();
}
ANKI_TEST(Gr, WorkGraphsJobManager)
{
    Bool bBenchmark, bWorkgraphs;
    // CVarSet::getSingleton().setMultiple(Array<const Char*, 2>{"Device", "1"});
    commonInitWg(bBenchmark, bWorkgraphs);

    const U32 queueRingBufferSize = nextPowerOfTwo(2 * 1024 * 1024);
    const U32 initialWorkItemCount = 128 * 1024;

    {
        // Create compute progs
        ShaderProgramPtr compProg;
        {
            ShaderPtr shader = loadShader(ANKI_SOURCE_DIRECTORY "/Tests/Gr/JobManagerCompute.hlsl", ShaderType::kCompute);

            ShaderProgramInitInfo progInit;
            progInit.m_computeShader = shader.get();
            compProg = GrManager::getSingleton().newShaderProgram(progInit);
        }

        ShaderProgramPtr wgProg;
        if(bWorkgraphs)
        {
            ShaderPtr shader = loadShader(ANKI_SOURCE_DIRECTORY "/Tests/Gr/JobManagerWg.hlsl", ShaderType::kWorkGraph);

            ShaderProgramInitInfo progInit;
            Array<WorkGraphNodeSpecialization, 1> specializations = {{{"main", UVec3((initialWorkItemCount + 64 - 1) / 64, 1, 1)}}};
            progInit.m_workGraph.m_nodeSpecializations = specializations;
            progInit.m_workGraph.m_shader = shader.get();
            wgProg = GrManager::getSingleton().newShaderProgram(progInit);
        }

        // Scratch buff
        BufferPtr scratchBuff;
        if(bWorkgraphs)
        {
            BufferInitInfo scratchInit("scratch");
            scratchInit.m_size = wgProg->getWorkGraphMemoryRequirements();
            scratchInit.m_usage = BufferUsageBit::kAllUav | BufferUsageBit::kAllSrv;
            scratchBuff = GrManager::getSingleton().newBuffer(scratchInit);
        }

        DynamicArray<U32> initialWorkItems;
        U32 finalValue = 0;
        U32 workItemCount = 0;
        {
            initialWorkItems.resize(initialWorkItemCount);
            for(U32 i = 0; i < initialWorkItems.getSize(); ++i)
            {
                const Bool bDeterministic = bBenchmark;
                const U32 level = ((bDeterministic) ? i : rand()) % 4;
                const U32 payload = ((bDeterministic) ? 1 : rand()) % 4;
                initialWorkItems[i] = (level << 16) | payload;
            }
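
            // CPU reference: simulate the job expansion with an explicit stack. A work item packs (level << 16) |
            // payload; a level-0 item contributes its payload to the final sum, and any other item expands into 4
            // copies of itself at level - 1. An item starting at level L therefore produces 4^L leaf items, each
            // adding its payload once.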
            DynamicArray<U32> initialWorkItems2 = initialWorkItems;
            while(initialWorkItems2.getSize() > 0)
            {
                const U32 workItem = initialWorkItems2.getBack();
                initialWorkItems2.popBack();

                const U32 level = workItem >> 16u;
                const U32 payload = workItem & 0xFFFFu;

                ++workItemCount;

                if(level == 0)
                {
                    finalValue += payload;
                }
                else
                {
                    U32 newWorkItem = (level - 1) << 16u;
                    newWorkItem |= payload;

                    for(U32 i = 0; i < 4; ++i)
                    {
                        initialWorkItems2.emplaceBack(newWorkItem);
                    }
                }
            }
        }
        BufferPtr resultBuff = createBuffer<U32>(BufferUsageBit::kAllUav, 0u, 2);

        BufferPtr queueRingBuff;
        if(!bWorkgraphs)
        {
            queueRingBuff = createBuffer<U32>(BufferUsageBit::kAllUav, 0u, queueRingBufferSize);

            BufferPtr tempBuff = createBuffer<U32>(BufferUsageBit::kCopySource, initialWorkItems);

            CommandBufferPtr cmdb = GrManager::getSingleton().newCommandBuffer(CommandBufferInitInfo());
            cmdb->copyBufferToBuffer(BufferView(tempBuff.get(), 0, initialWorkItems.getSizeInBytes()),
                                     BufferView(queueRingBuff.get(), 0, initialWorkItems.getSizeInBytes()));
            cmdb->endRecording();

            FencePtr fence;
            GrManager::getSingleton().submit(cmdb.get(), {}, &fence);
            ANKI_TEST_EXPECT_EQ(fence->clientWait(kMaxSecond), true);
        }

        BufferPtr initialWorkItemsBuff;
        if(bWorkgraphs)
        {
            initialWorkItemsBuff = createBuffer<U32>(BufferUsageBit::kAllUav, 0u, initialWorkItemCount);

            BufferPtr tempBuff = createBuffer<U32>(BufferUsageBit::kCopySource, initialWorkItems);

            CommandBufferPtr cmdb = GrManager::getSingleton().newCommandBuffer(CommandBufferInitInfo());
            cmdb->copyBufferToBuffer(BufferView(tempBuff.get(), 0, initialWorkItems.getSizeInBytes()),
                                     BufferView(initialWorkItemsBuff.get(), 0, initialWorkItems.getSizeInBytes()));
            cmdb->endRecording();

            FencePtr fence;
            GrManager::getSingleton().submit(cmdb.get(), {}, &fence);
            ANKI_TEST_EXPECT_EQ(fence->clientWait(kMaxSecond), true);
        }

        BufferPtr queueBuff;
        if(!bWorkgraphs)
        {
            struct Queue
            {
                U32 m_spinlock;
                U32 m_head;
                U32 m_tail;
                U32 m_pendingWork;
            };

            Queue q = {};
            q.m_head = initialWorkItems.getSize();
            queueBuff = createBuffer(BufferUsageBit::kAllUav, q, 1);
        }
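
        // The compute fallback implements a GPU-side job manager: a fixed set of 256 persistent thread groups
        // (dispatched below) pops and pushes work through the ring buffer guarded by the Queue struct above.
        // queueRingBufferSize is a power of two, so the (queueRingBufferSize - 1) fast constant can serve as an
        // index wrap mask; the exact consumption logic lives in JobManagerCompute.hlsl, which is not shown here.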
        ANKI_TEST_LOGI("Init complete");

        const U32 iterationsPerCmdb = 1;
        const U32 iterationCount = 1;
        runBenchmark(iterationCount, iterationsPerCmdb, bBenchmark, [&](CommandBuffer& cmdb) {
            if(!bWorkgraphs)
            {
                cmdb.bindShaderProgram(compProg.get());
                cmdb.bindUav(0, 0, BufferView(queueBuff.get()));
                cmdb.bindUav(1, 0, BufferView(queueRingBuff.get()));
                cmdb.bindUav(2, 0, BufferView(resultBuff.get()));

                UVec4 consts(queueRingBufferSize - 1);
                cmdb.setFastConstants(&consts, sizeof(consts));

                cmdb.dispatchCompute(256, 1, 1);
            }
            else
            {
                cmdb.bindShaderProgram(wgProg.get());
                cmdb.bindSrv(0, 0, BufferView(initialWorkItemsBuff.get()));
                cmdb.bindUav(0, 0, BufferView(resultBuff.get()));

                struct FirstNodeRecord
                {
                    UVec3 m_gridSize;
                };

                Array<FirstNodeRecord, 1> records;
                records[0].m_gridSize = UVec3((initialWorkItemCount + 64 - 1) / 64, 1, 1);

                cmdb.dispatchGraph(BufferView(scratchBuff.get()), records.getBegin(), records.getSize(), sizeof(records[0]));
            }
        });

        DynamicArray<U32> result;
        readBuffer(resultBuff, result);
        printf("Expecting %u, got %u. Error %u\n", finalValue, result[0], result[1]);
        ANKI_TEST_EXPECT_EQ(result[0], finalValue);
        ANKI_TEST_EXPECT_EQ(result[1], 0);
    }

    commonDestroy();
}