Browse Source

D3D: Workgraph tests

Panagiotis Christopoulos Charitos 1 year ago
parent
commit
8859c32cde

+ 16 - 2
AnKi/Gr/D3D/D3DShaderProgram.cpp

@@ -63,9 +63,9 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 			}
 		}
 	}
-	else if(inf.m_workGraphShader)
+	else if(inf.m_workGraph.m_shader)
 	{
-		m_shaders.emplaceBack(inf.m_workGraphShader);
+		m_shaders.emplaceBack(inf.m_workGraph.m_shader);
 	}
 	else
 	{
@@ -154,6 +154,19 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 		wgSubObj->IncludeAllAvailableNodes(); // Auto populate the graph
 		wgSubObj->SetProgramName(wgName);
 
+		GrDynamicArray<Array<WChar, 128>> nodeNames;
+		nodeNames.resize(inf.m_workGraph.m_nodeSpecializations.getSize());
+		for(U32 i = 0; i < inf.m_workGraph.m_nodeSpecializations.getSize(); ++i)
+		{
+			const WorkGraphNodeSpecialization& specialization = inf.m_workGraph.m_nodeSpecializations[i];
+			specialization.m_nodeName.toWideChars(nodeNames[i].getBegin(), nodeNames[i].getSize());
+			CD3DX12_BROADCASTING_LAUNCH_NODE_OVERRIDES* spec = wgSubObj->CreateBroadcastingLaunchNodeOverrides(nodeNames[i].getBegin());
+
+			ANKI_ASSERT(specialization.m_maxNodeDispatchGrid > UVec3(0u));
+			spec->MaxDispatchGrid(specialization.m_maxNodeDispatchGrid.x(), specialization.m_maxNodeDispatchGrid.y(),
+								  specialization.m_maxNodeDispatchGrid.z());
+		}
+
 		// Create state obj
 		ANKI_D3D_CHECK(getDevice().CreateStateObject(stateObj, IID_PPV_ARGS(&m_workGraph.m_stateObject)));
 
@@ -167,6 +180,7 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 		const UINT wgIndex = spWGProps->GetWorkGraphIndex(wgName);
 		D3D12_WORK_GRAPH_MEMORY_REQUIREMENTS memReqs;
 		spWGProps->GetWorkGraphMemoryRequirements(wgIndex, &memReqs);
+		ANKI_ASSERT(spWGProps->GetNumEntrypoints(wgIndex) == 1);
 
 		m_workGraphScratchBufferSize = memReqs.MaxSizeInBytes;
 	}

+ 1 - 1
AnKi/Gr/ShaderProgram.cpp

@@ -111,7 +111,7 @@ Bool ShaderProgramInitInfo::isValid() const
 		return false;
 	}
 
-	const Bool workGraph = m_workGraphShader != nullptr;
+	const Bool workGraph = m_workGraph.m_shader != nullptr;
 
 	const U32 options = !!graphicsMask + compute + !!rtMask + workGraph;
 	if(options != 1)

+ 17 - 7
AnKi/Gr/ShaderProgram.h

@@ -22,13 +22,11 @@ public:
 };
 
 /// @memberof ShaderProgramInitInfo
-class RayTracingShaders
+class WorkGraphNodeSpecialization
 {
 public:
-	WeakArray<Shader*> m_rayGenShaders;
-	WeakArray<Shader*> m_missShaders;
-	WeakArray<RayTracingHitGroup> m_hitGroups;
-	U32 m_maxRecursionDepth = 1;
+	CString m_nodeName;
+	UVec3 m_maxNodeDispatchGrid;
 };
 
 /// ShaderProgram init info.
@@ -42,10 +40,22 @@ public:
 	Shader* m_computeShader = nullptr;
 
 	/// Option 3
-	RayTracingShaders m_rayTracingShaders;
+	class
+	{
+	public:
+		WeakArray<Shader*> m_rayGenShaders;
+		WeakArray<Shader*> m_missShaders;
+		WeakArray<RayTracingHitGroup> m_hitGroups;
+		U32 m_maxRecursionDepth = 1;
+	} m_rayTracingShaders;
 
 	/// Option 4
-	Shader* m_workGraphShader = nullptr;
+	class
+	{
+	public:
+		Shader* m_shader = nullptr;
+		ConstWeakArray<WorkGraphNodeSpecialization> m_nodeSpecializations;
+	} m_workGraph;
 
 	ShaderProgramInitInfo(CString name = {})
 		: GrBaseInitInfo(name)

+ 19 - 2
AnKi/ShaderCompiler/ShaderCompiler.cpp

@@ -560,8 +560,25 @@ Error doReflectionDxil(ConstWeakArray<U8> dxil, ShaderType type, ShaderReflectio
 				return Error::kUserData;
 			}
 
-			refl.m_descriptor.m_bindings[bindDesc.Space][refl.m_descriptor.m_bindingCounts[bindDesc.Space]] = akBinding;
-			++refl.m_descriptor.m_bindingCounts[bindDesc.Space];
+			Bool skip = false;
+			if(isLib)
+			{
+				// Search if the binding exists because it may repeat
+				for(U32 i = 0; i < refl.m_descriptor.m_bindingCounts[bindDesc.Space]; ++i)
+				{
+					if(refl.m_descriptor.m_bindings[bindDesc.Space][i] == akBinding)
+					{
+						skip = true;
+						break;
+					}
+				}
+			}
+
+			if(!skip)
+			{
+				refl.m_descriptor.m_bindings[bindDesc.Space][refl.m_descriptor.m_bindingCounts[bindDesc.Space]] = akBinding;
+				++refl.m_descriptor.m_bindingCounts[bindDesc.Space];
+			}
 		}
 	}
 

+ 0 - 1
AnKi/Util/CMakeLists.txt

@@ -5,7 +5,6 @@ set(sources
 	Filesystem.cpp
 	MemoryPool.cpp
 	System.cpp
-	HighRezTimer.cpp
 	ThreadPool.cpp
 	ThreadHive.cpp
 	Hash.cpp

+ 0 - 36
AnKi/Util/HighRezTimer.cpp

@@ -1,36 +0,0 @@
-// Copyright (C) 2009-present, Panagiotis Christopoulos Charitos and contributors.
-// All rights reserved.
-// Code licensed under the BSD License.
-// http://www.anki3d.org/LICENSE
-
-#include <AnKi/Util/HighRezTimer.h>
-#include <AnKi/Util/Assert.h>
-
-namespace anki {
-
-void HighRezTimer::start()
-{
-	m_startTime = getCurrentTime();
-	m_stopTime = 0.0;
-}
-
-void HighRezTimer::stop()
-{
-	ANKI_ASSERT(m_startTime != 0.0);
-	ANKI_ASSERT(m_stopTime == 0.0);
-	m_stopTime = getCurrentTime();
-}
-
-Second HighRezTimer::getElapsedTime() const
-{
-	if(m_stopTime == 0.0)
-	{
-		return getCurrentTime() - m_startTime;
-	}
-	else
-	{
-		return m_stopTime - m_startTime;
-	}
-}
-
-} // end namespace anki

+ 16 - 3
AnKi/Util/HighRezTimer.h

@@ -5,6 +5,7 @@
 
 #pragma once
 
+#include <AnKi/Util/Assert.h>
 #include <AnKi/Util/StdTypes.h>
 
 namespace anki {
@@ -17,13 +18,25 @@ class HighRezTimer
 {
 public:
 	/// Start the timer
-	void start();
+	void start()
+	{
+		m_startTime = getCurrentTime();
+		m_stopTime = 0.0;
+	}
 
 	/// Stop the timer
-	void stop();
+	void stop()
+	{
+		ANKI_ASSERT(m_startTime != 0.0);
+		ANKI_ASSERT(m_stopTime == 0.0);
+		m_stopTime = getCurrentTime();
+	}
 
 	/// Get the time elapsed between start and stop (if its stopped) or between start and the current time.
-	Second getElapsedTime() const;
+	Second getElapsedTime() const
+	{
+		return (m_stopTime == 0.0) ? getCurrentTime() - m_startTime : m_stopTime - m_startTime;
+	}
 
 	/// Get the current date's seconds
 	static Second getCurrentTime();

+ 4 - 4
Tests/Gr/Gr.cpp

@@ -264,9 +264,9 @@ void main()
 		texInit.m_usage = TextureUsageBit::kSampledCompute | TextureUsageBit::kTransferDestination;
 		TexturePtr tex = createTexture2d(texInit, kMagicVec * 2.0f);
 
-		BufferPtr buff = createBuffer(BufferUsageBit::kAllTexel, kMagicVec * 2.0f, "buff");
-		BufferPtr rwstructured = createBuffer(BufferUsageBit::kAllStorage, kInvalidVec, "rwstructured");
-		BufferPtr rwbuff = createBuffer(BufferUsageBit::kAllTexel, kInvalidVec, "rwbuff");
+		BufferPtr buff = createBuffer(BufferUsageBit::kAllTexel, kMagicVec * 2.0f, 1, "buff");
+		BufferPtr rwstructured = createBuffer(BufferUsageBit::kAllStorage, kInvalidVec, 1, "rwstructured");
+		BufferPtr rwbuff = createBuffer(BufferUsageBit::kAllTexel, kInvalidVec, 1, "rwbuff");
 
 		Array<TexturePtr, 3> rwtex;
 
@@ -275,7 +275,7 @@ void main()
 		rwtex[1] = createTexture2d(texInit, kInvalidVec);
 		rwtex[2] = createTexture2d(texInit, kInvalidVec);
 
-		BufferPtr consts = createBuffer(BufferUsageBit::kUniformCompute, kMagicVec * 3.0f, "consts");
+		BufferPtr consts = createBuffer(BufferUsageBit::kUniformCompute, kMagicVec * 3.0f, 1, "consts");
 
 		SamplerInitInfo samplInit;
 		SamplerPtr sampler = GrManager::getSingleton().newSampler(samplInit);

+ 48 - 25
Tests/Gr/GrCommon.h

@@ -65,18 +65,26 @@ inline ShaderProgramPtr createVertFragProg(CString vert, CString frag)
 const U kWidth = 1024;
 const U kHeight = 768;
 
-inline void commonInit()
+inline void commonInit(Bool validation = true)
 {
 	DefaultMemoryPool::allocateSingleton(allocAligned, nullptr);
 	ShaderCompilerMemoryPool::allocateSingleton(allocAligned, nullptr);
 	g_windowWidthCVar.set(kWidth);
 	g_windowHeightCVar.set(kHeight);
-	g_validationCVar.set(true);
 	g_vsyncCVar.set(false);
 	g_debugMarkersCVar.set(true);
+	if(validation)
+	{
+		g_validationCVar.set(true);
+		[[maybe_unused]] const Error err = CVarSet::getSingleton().setMultiple(Array<const Char*, 2>{"GpuValidation", "1"});
+	}
 
 	initWindow();
 	ANKI_TEST_EXPECT_NO_ERR(Input::allocateSingleton().init());
+
+	Input::allocateSingleton();
+	ANKI_TEST_EXPECT_NO_ERR(Input::getSingleton().init());
+
 	initGrManager();
 }
 
@@ -85,6 +93,7 @@ inline void commonDestroy()
 	GrManager::freeSingleton();
 	Input::freeSingleton();
 	NativeWindow::freeSingleton();
+	Input::freeSingleton();
 	ShaderCompilerMemoryPool::freeSingleton();
 	DefaultMemoryPool::freeSingleton();
 }
@@ -92,35 +101,38 @@ inline void commonDestroy()
 template<typename T>
 inline BufferPtr createBuffer(BufferUsageBit usage, ConstWeakArray<T> data, CString name = {})
 {
-	BufferInitInfo buffInit;
-	if(!name.isEmpty())
-	{
-		buffInit.setName(name);
-	}
-
-	buffInit.m_mapAccess = BufferMapAccessBit::kWrite;
-	buffInit.m_usage = usage | BufferUsageBit::kTransferSource;
-	buffInit.m_size = data.getSizeInBytes();
-
-	BufferPtr buff = GrManager::getSingleton().newBuffer(buffInit);
-
-	T* inData = static_cast<T*>(buff->map(0, kMaxPtrSize, BufferMapAccessBit::kWrite));
+	BufferPtr copyBuff =
+		GrManager::getSingleton().newBuffer(BufferInitInfo(data.getSizeInBytes(), BufferUsageBit::kTransferSource, BufferMapAccessBit::kWrite));
 
+	T* inData = static_cast<T*>(copyBuff->map(0, kMaxPtrSize, BufferMapAccessBit::kWrite));
 	for(U32 i = 0; i < data.getSize(); ++i)
 	{
 		inData[i] = data[i];
 	}
+	copyBuff->unmap();
+
+	BufferPtr buff = GrManager::getSingleton().newBuffer(
+		BufferInitInfo(data.getSizeInBytes(), usage | BufferUsageBit::kTransferSource, BufferMapAccessBit::kNone, name));
+
+	CommandBufferInitInfo cmdbInit;
+	cmdbInit.m_flags |= CommandBufferFlag::kSmallBatch;
+	CommandBufferPtr cmdb = GrManager::getSingleton().newCommandBuffer(cmdbInit);
+	cmdb->copyBufferToBuffer(BufferView(copyBuff.get()), BufferView(buff.get()));
+	cmdb->endRecording();
 
-	buff->unmap();
+	FencePtr fence;
+	GrManager::getSingleton().submit(cmdb.get(), {}, &fence);
+	fence->clientWait(kMaxSecond);
 
 	return buff;
 };
 
 template<typename T>
-inline BufferPtr createBuffer(BufferUsageBit usage, T data, CString name = {})
+inline BufferPtr createBuffer(BufferUsageBit usage, T pattern, U32 count, CString name = {})
 {
-	ConstWeakArray<T> arr(&data, 1);
-	return createBuffer(usage, arr, name);
+	DynamicArray<T> arr;
+	arr.resize(count, pattern);
+	return createBuffer(usage, ConstWeakArray<T>(arr), name);
 }
 
 template<typename T>
@@ -159,7 +171,7 @@ inline TexturePtr createTexture2d(const TextureInitInfo& texInit, T initialValue
 };
 
 template<typename T>
-inline void validateBuffer(BufferPtr buff, T value)
+inline void readBuffer(BufferPtr buff, DynamicArray<T>& out)
 {
 	BufferPtr tmpBuff;
 
@@ -184,13 +196,24 @@ inline void validateBuffer(BufferPtr buff, T value)
 		fence->clientWait(kMaxSecond);
 	}
 
-	const T* inData = static_cast<T*>(tmpBuff->map(0, kMaxPtrSize, BufferMapAccessBit::kRead));
-	const T* endData = inData + (tmpBuff->getSize() / sizeof(T));
-	for(; inData < endData; ++inData)
+	ANKI_ASSERT((buff->getSize() % sizeof(T)) == 0);
+	out.resize(U32(buff->getSize() / sizeof(T)));
+
+	const void* data = tmpBuff->map(0, kMaxPtrSize, BufferMapAccessBit::kRead);
+	memcpy(out.getBegin(), data, buff->getSize());
+	tmpBuff->unmap();
+}
+
+template<typename T>
+inline void validateBuffer(BufferPtr buff, T value)
+{
+	DynamicArray<T> cpuBuff;
+	readBuffer<T>(buff, cpuBuff);
+
+	for(const T& x : cpuBuff)
 	{
-		ANKI_TEST_EXPECT_EQ(*inData, value);
+		ANKI_TEST_EXPECT_EQ(x, value);
 	}
-	tmpBuff->unmap();
 }
 
 } // end namespace anki

+ 365 - 4
Tests/Gr/GrWorkGraphs.cpp

@@ -5,11 +5,48 @@
 
 #include <Tests/Framework/Framework.h>
 #include <Tests/Gr/GrCommon.h>
+#include <AnKi/Util/HighRezTimer.h>
 #include <AnKi/Gr.h>
 
+using namespace anki;
+
+static void clearSwapchain(CommandBufferPtr cmdb = CommandBufferPtr())
+{
+	const Bool continueCmdb = cmdb.isCreated();
+
+	TexturePtr presentTex = GrManager::getSingleton().acquireNextPresentableTexture();
+
+	if(!continueCmdb)
+	{
+		CommandBufferInitInfo cinit;
+		cinit.m_flags = CommandBufferFlag::kGeneralWork | CommandBufferFlag::kSmallBatch;
+		cmdb = GrManager::getSingleton().newCommandBuffer(cinit);
+	}
+
+	const TextureBarrierInfo barrier = {TextureView(presentTex.get(), TextureSubresourceDesc::all()), TextureUsageBit::kNone,
+										TextureUsageBit::kFramebufferWrite};
+	cmdb->setPipelineBarrier({&barrier, 1}, {}, {});
+
+	RenderTarget rt;
+	rt.m_textureView = TextureView(presentTex.get(), TextureSubresourceDesc::all());
+	rt.m_clearValue.m_colorf = {1.0f, F32(rand()) / F32(RAND_MAX), 1.0f, 1.0f};
+	cmdb->beginRenderPass({rt});
+	cmdb->endRenderPass();
+
+	const TextureBarrierInfo barrier2 = {TextureView(presentTex.get(), TextureSubresourceDesc::all()), TextureUsageBit::kFramebufferWrite,
+										 TextureUsageBit::kPresent};
+	cmdb->setPipelineBarrier({&barrier2, 1}, {}, {});
+
+	if(!continueCmdb)
+	{
+		cmdb->endRecording();
+		GrManager::getSingleton().submit(cmdb.get());
+	}
+}
+
 ANKI_TEST(Gr, WorkGraphHelloWorld)
 {
-	// CVarSet::getSingleton().setMultiple(Array<const Char*, 2>{"Device", "2"});
+	// CVarSet::getSingleton().setMultiple(Array<const Char*, 2>{"Device", "1"});
 	commonInit();
 
 	{
@@ -33,7 +70,7 @@ struct ThirdNodeRecord
 
 RWStructuredBuffer<uint> g_buff : register(u0);
 
-[Shader("node")] [NodeLaunch("broadcasting")] [NodeMaxDispatchGrid(16, 1, 1)] [NumThreads(16, 1, 1)]
+[Shader("node")] [NodeLaunch("broadcasting")] [NodeIsProgramEntry] [NodeMaxDispatchGrid(1, 1, 1)] [NumThreads(16, 1, 1)]
 void main(DispatchNodeInputRecord<FirstNodeRecord> inp, [MaxRecords(2)] NodeOutput<SecondNodeRecord> secondNode, uint svGroupIndex : SV_GroupIndex)
 {
 	GroupNodeOutputRecords<SecondNodeRecord> rec = secondNode.GetGroupNodeOutputRecords(2);
@@ -73,10 +110,12 @@ void thirdNode([MaxRecords(32)] GroupNodeInputRecords<ThirdNodeRecord> inp, uint
 		ShaderPtr shader = createShader(kSrc, ShaderType::kWorkGraph);
 
 		ShaderProgramInitInfo progInit;
-		progInit.m_workGraphShader = shader.get();
+		progInit.m_workGraph.m_shader = shader.get();
+		WorkGraphNodeSpecialization wgSpecialization = {"main", UVec3(4, 1, 1)};
+		progInit.m_workGraph.m_nodeSpecializations = ConstWeakArray<WorkGraphNodeSpecialization>(&wgSpecialization, 1);
 		ShaderProgramPtr prog = GrManager::getSingleton().newShaderProgram(progInit);
 
-		BufferPtr counterBuff = createBuffer(BufferUsageBit::kAllStorage | BufferUsageBit::kTransferSource, 0u, "CounterBuffer");
+		BufferPtr counterBuff = createBuffer(BufferUsageBit::kAllStorage | BufferUsageBit::kTransferSource, 0u, 1, "CounterBuffer");
 
 		BufferInitInfo scratchInit("scratch");
 		scratchInit.m_size = prog->getWorkGraphMemoryRequirements();
@@ -111,3 +150,325 @@ void thirdNode([MaxRecords(32)] GroupNodeInputRecords<ThirdNodeRecord> inp, uint
 
 	commonDestroy();
 }
+
+ANKI_TEST(Gr, WorkGraphAmplification)
+{
+	constexpr Bool benchmark = true;
+
+	// CVarSet::getSingleton().setMultiple(Array<const Char*, 2>{"Device", "2"});
+	commonInit(!benchmark);
+
+	{
+		const Char* kSrc = R"(
+struct FirstNodeRecord
+{
+	uint3 m_dispatchGrid : SV_DispatchGrid;
+};
+
+struct SecondNodeRecord
+{
+	uint3 m_dispatchGrid : SV_DispatchGrid;
+	uint m_objectIndex;
+};
+
+struct Aabb
+{
+	uint m_min;
+	uint m_max;
+};
+
+struct Object
+{
+	uint m_positionsStart; // Points to g_positions
+	uint m_positionCount;
+};
+
+RWStructuredBuffer<Aabb> g_aabbs : register(u0);
+StructuredBuffer<Object> g_objects : register(t0);
+StructuredBuffer<uint> g_positions : register(t1);
+
+#define THREAD_COUNT 64u
+
+// Operates per object
+[Shader("node")] [NodeLaunch("broadcasting")] [NodeIsProgramEntry] [NodeMaxDispatchGrid(1, 1, 1)]
+[NumThreads(THREAD_COUNT, 1, 1)]
+void main(DispatchNodeInputRecord<FirstNodeRecord> inp, [MaxRecords(THREAD_COUNT)] NodeOutput<SecondNodeRecord> computeAabb,
+		  uint svGroupIndex : SV_GroupIndex, uint svDispatchThreadId : SV_DispatchThreadId)
+{
+	GroupNodeOutputRecords<SecondNodeRecord> recs = computeAabb.GetGroupNodeOutputRecords(THREAD_COUNT);
+
+	const Object obj = g_objects[svDispatchThreadId];
+
+	recs[svGroupIndex].m_objectIndex = svDispatchThreadId;
+	recs[svGroupIndex].m_dispatchGrid = uint3((obj.m_positionCount + (THREAD_COUNT - 1)) / THREAD_COUNT, 1, 1);
+
+	recs.OutputComplete();
+}
+
+groupshared Aabb g_aabb;
+
+// Operates per position
+[Shader("node")] [NodeLaunch("broadcasting")] [NodeMaxDispatchGrid(1, 1, 1)] [NumThreads(THREAD_COUNT, 1, 1)]
+void computeAabb(DispatchNodeInputRecord<SecondNodeRecord> inp, uint svDispatchThreadId : SV_DispatchThreadId, uint svGroupIndex : SV_GroupIndex)
+{
+	const Object obj = g_objects[inp.Get().m_objectIndex];
+
+	svDispatchThreadId = min(svDispatchThreadId, obj.m_positionCount - 1);
+
+	if(svGroupIndex == 0)
+	{
+		g_aabb.m_min = 0xFFFFFFFF;
+		g_aabb.m_max = 0;
+	}
+
+	Barrier(GROUP_SHARED_MEMORY, GROUP_SCOPE | GROUP_SYNC);
+
+	const uint positionIndex = obj.m_positionsStart + svDispatchThreadId;
+
+	const uint pos = g_positions[positionIndex];
+	InterlockedMin(g_aabb.m_min, pos);
+	InterlockedMax(g_aabb.m_max, pos);
+
+	Barrier(GROUP_SHARED_MEMORY, GROUP_SCOPE | GROUP_SYNC);
+
+	InterlockedMin(g_aabbs[inp.Get().m_objectIndex].m_min, g_aabb.m_min);
+	InterlockedMax(g_aabbs[inp.Get().m_objectIndex].m_max, g_aabb.m_max);
+}
+)";
+
+		const Char* kComputeSrc = R"(
+struct Aabb
+{
+	uint m_min;
+	uint m_max;
+};
+
+struct Object
+{
+	uint m_positionsStart; // Points to g_positions
+	uint m_positionCount;
+};
+
+struct PushConsts
+{
+	uint m_objectIndex;
+	uint m_padding1;
+	uint m_padding2;
+	uint m_padding3;
+};
+
+RWStructuredBuffer<Aabb> g_aabbs : register(u0);
+StructuredBuffer<Object> g_objects : register(t0);
+StructuredBuffer<uint> g_positions : register(t1);
+
+#if defined(__spirv__)
+[[vk::push_constant]] ConstantBuffer<PushConsts> g_pushConsts;
+#else
+ConstantBuffer<PushConsts> g_pushConsts : register(b0, space3000);
+#endif
+
+#define THREAD_COUNT 64u
+
+groupshared Aabb g_aabb;
+
+[NumThreads(THREAD_COUNT, 1, 1)]
+void main(uint svDispatchThreadId : SV_DispatchThreadId, uint svGroupIndex : SV_GroupIndex)
+{
+	const Object obj = g_objects[g_pushConsts.m_objectIndex];
+
+	svDispatchThreadId = min(svDispatchThreadId, obj.m_positionCount - 1);
+
+	if(svGroupIndex == 0)
+	{
+		g_aabb.m_min = 0xFFFFFFFF;
+		g_aabb.m_max = 0;
+	}
+
+	Barrier(GROUP_SHARED_MEMORY, GROUP_SCOPE | GROUP_SYNC);
+
+	const uint positionIndex = obj.m_positionsStart + svDispatchThreadId;
+
+	const uint pos = g_positions[positionIndex];
+	InterlockedMin(g_aabb.m_min, pos);
+	InterlockedMax(g_aabb.m_max, pos);
+
+	Barrier(GROUP_SHARED_MEMORY, GROUP_SCOPE | GROUP_SYNC);
+
+	InterlockedMin(g_aabbs[g_pushConsts.m_objectIndex].m_min, g_aabb.m_min);
+	InterlockedMax(g_aabbs[g_pushConsts.m_objectIndex].m_max, g_aabb.m_max);
+}
+)";
+
+		constexpr U32 kObjectCount = 4000 * 64;
+		constexpr U32 kPositionsPerObject = 10 * 64; // 1 * 1024;
+		constexpr U32 kThreadCount = 64;
+		constexpr Bool useWorkgraphs = true;
+
+		ShaderProgramPtr prog;
+		if(useWorkgraphs)
+		{
+			ShaderPtr shader = createShader(kSrc, ShaderType::kWorkGraph);
+			ShaderProgramInitInfo progInit;
+			Array<WorkGraphNodeSpecialization, 2> specializations = {
+				{{"main", UVec3((kObjectCount + kThreadCount - 1) / kThreadCount, 1, 1)},
+				 {"computeAabb", UVec3((kPositionsPerObject + (kThreadCount - 1)) / kThreadCount, 1, 1)}}};
+			progInit.m_workGraph.m_nodeSpecializations = specializations;
+			progInit.m_workGraph.m_shader = shader.get();
+			prog = GrManager::getSingleton().newShaderProgram(progInit);
+		}
+		else
+		{
+			ShaderPtr shader = createShader(kComputeSrc, ShaderType::kCompute);
+			ShaderProgramInitInfo progInit;
+			progInit.m_computeShader = shader.get();
+			prog = GrManager::getSingleton().newShaderProgram(progInit);
+		}
+
+		struct Aabb
+		{
+			U32 m_min = kMaxU32;
+			U32 m_max = 0;
+
+			Bool operator==(const Aabb&) const = default;
+		};
+
+		struct Object
+		{
+			U32 m_positionsStart; // Points to g_positions
+			U32 m_positionCount;
+		};
+
+		// Objects
+		DynamicArray<Object> objects;
+		objects.resize(kObjectCount);
+		U32 positionCount = 0;
+		for(Object& obj : objects)
+		{
+			obj.m_positionsStart = positionCount;
+			obj.m_positionCount = kPositionsPerObject;
+
+			positionCount += obj.m_positionCount;
+		}
+
+		printf("Obj count %u, pos count %u\n", kObjectCount, positionCount);
+
+		BufferPtr objBuff = createBuffer(BufferUsageBit::kStorageComputeRead, ConstWeakArray(objects), "Objects");
+
+		// AABBs
+		BufferPtr aabbsBuff = createBuffer(BufferUsageBit::kStorageComputeWrite, Aabb(), kObjectCount, "AABBs");
+
+		// Positions
+		GrDynamicArray<U32> positions;
+		positions.resize(positionCount);
+		positionCount = 0;
+		for(U32 iobj = 0; iobj < kObjectCount; ++iobj)
+		{
+			const Object& obj = objects[iobj];
+
+			const U32 min = getRandomRange<U32>(0, kMaxU32 / 2 - 1);
+			const U32 max = getRandomRange<U32>(kMaxU32 / 2, kMaxU32);
+
+			for(U32 ipos = obj.m_positionsStart; ipos < obj.m_positionsStart + obj.m_positionCount; ++ipos)
+			{
+				positions[ipos] = getRandomRange<U32>(min, max);
+
+				positions[ipos] = iobj;
+			}
+
+			positionCount += obj.m_positionCount;
+		}
+
+		BufferPtr posBuff = createBuffer(BufferUsageBit::kStorageComputeRead, ConstWeakArray(positions), "Positions");
+
+		// Execute
+		for(U32 i = 0; i < ((benchmark) ? 200 : 1); ++i)
+		{
+			[[maybe_unused]] const Error err = Input::getSingleton().handleEvents();
+
+			BufferPtr scratchBuff;
+			if(useWorkgraphs)
+			{
+				BufferInitInfo scratchInit("scratch");
+				scratchInit.m_size = prog->getWorkGraphMemoryRequirements();
+				scratchInit.m_usage = BufferUsageBit::kAllStorage;
+				scratchBuff = GrManager::getSingleton().newBuffer(scratchInit);
+			}
+
+			const Second timeA = HighRezTimer::getCurrentTime();
+
+			CommandBufferPtr cmdb;
+			if(useWorkgraphs)
+			{
+				struct FirstNodeRecord
+				{
+					UVec3 m_gridSize;
+				};
+
+				Array<FirstNodeRecord, 1> records;
+				records[0].m_gridSize = UVec3((objects.getSize() + kThreadCount - 1) / kThreadCount, 1, 1);
+
+				cmdb = GrManager::getSingleton().newCommandBuffer(
+					CommandBufferInitInfo(CommandBufferFlag::kSmallBatch | CommandBufferFlag::kGeneralWork));
+				cmdb->bindShaderProgram(prog.get());
+				cmdb->bindStorageBuffer(ANKI_REG(u0), BufferView(aabbsBuff.get()));
+				cmdb->bindStorageBuffer(ANKI_REG(t0), BufferView(objBuff.get()));
+				cmdb->bindStorageBuffer(ANKI_REG(t1), BufferView(posBuff.get()));
+				cmdb->dispatchGraph(BufferView(scratchBuff.get()), records.getBegin(), records.getSize(), sizeof(records[0]));
+			}
+			else
+			{
+				cmdb = GrManager::getSingleton().newCommandBuffer(CommandBufferInitInfo(CommandBufferFlag::kGeneralWork));
+				cmdb->bindShaderProgram(prog.get());
+				cmdb->bindStorageBuffer(ANKI_REG(u0), BufferView(aabbsBuff.get()));
+				cmdb->bindStorageBuffer(ANKI_REG(t0), BufferView(objBuff.get()));
+				cmdb->bindStorageBuffer(ANKI_REG(t1), BufferView(posBuff.get()));
+
+				for(U32 iobj = 0; iobj < kObjectCount; ++iobj)
+				{
+					const UVec4 pc(iobj);
+					cmdb->setPushConstants(&pc, sizeof(pc));
+
+					cmdb->dispatchCompute((objects[iobj].m_positionCount + kThreadCount - 1) / kThreadCount, 1, 1);
+				}
+			}
+
+			clearSwapchain(cmdb);
+
+			cmdb->endRecording();
+
+			const Second timeB = HighRezTimer::getCurrentTime();
+
+			FencePtr fence;
+			GrManager::getSingleton().submit(cmdb.get(), {}, &fence);
+			fence->clientWait(kMaxSecond);
+
+			GrManager::getSingleton().swapBuffers();
+
+			const Second timeC = HighRezTimer::getCurrentTime();
+
+			printf("GPU time: %fms, cmdb build time: %fms\n", (timeC - timeB) * 1000.0, (timeB - timeA) * 1000.0);
+		}
+
+		// Check
+		DynamicArray<Aabb> aabbs;
+		readBuffer(aabbsBuff, aabbs);
+		for(U32 i = 0; i < kObjectCount; ++i)
+		{
+			const Object& obj = objects[i];
+			Aabb aabb;
+			for(U32 ipos = obj.m_positionsStart; ipos < obj.m_positionsStart + obj.m_positionCount; ++ipos)
+			{
+				aabb.m_min = min(aabb.m_min, positions[ipos]);
+				aabb.m_max = max(aabb.m_max, positions[ipos]);
+			}
+
+			if(aabb != aabbs[i])
+			{
+				printf("%u: %u %u | %u %u\n", i, aabb.m_min, aabbs[i].m_min, aabb.m_max, aabbs[i].m_max);
+			}
+			ANKI_TEST_EXPECT_EQ(aabb, aabbs[i]);
+		}
+	}
+
+	commonDestroy();
+}