Browse Source

Optimize the local light binning

Panagiotis Christopoulos Charitos 4 months ago
parent
commit
e6f942965d

+ 16 - 0
AnKi/Renderer/LightShading.cpp

@@ -79,6 +79,8 @@ void LightShading::run(const RenderingContext& ctx, RenderPassWorkContext& rgrap
 
 	// Do light shading first
 	{
+		cmdb.pushDebugMarker("LightShading", Vec3(0.0f, 1.0f, 1.0f));
+
 		cmdb.bindShaderProgram(m_lightShading.m_grProg.get());
 		cmdb.setDepthWrite(false);
 
@@ -110,10 +112,14 @@ void LightShading::run(const RenderingContext& ctx, RenderPassWorkContext& rgrap
 
 		// Draw
 		drawQuad(cmdb);
+
+		cmdb.popDebugMarker();
 	}
 
 	// Skybox
 	{
+		cmdb.pushDebugMarker("Skybox", Vec3(0.0f, 1.0f, 1.0f));
+
 		cmdb.setDepthCompareOperation(CompareOperation::kEqual);
 
 		const SkyboxComponent* sky = SceneGraph::getSingleton().getSkybox();
@@ -171,10 +177,14 @@ void LightShading::run(const RenderingContext& ctx, RenderPassWorkContext& rgrap
 
 		// Restore state
 		cmdb.setDepthCompareOperation(CompareOperation::kLess);
+
+		cmdb.popDebugMarker();
 	}
 
 	// Apply the fog
 	{
+		cmdb.pushDebugMarker("LightApplyFog", Vec3(0.0f, 1.0f, 1.0f));
+
 		cmdb.bindShaderProgram(m_applyFog.m_grProg.get());
 
 		// Bind all
@@ -206,6 +216,8 @@ void LightShading::run(const RenderingContext& ctx, RenderPassWorkContext& rgrap
 
 		// Reset state
 		cmdb.setBlendFactors(0, BlendFactor::kOne, BlendFactor::kZero);
+
+		cmdb.popDebugMarker();
 	}
 
 	// Debug stuff
@@ -216,6 +228,8 @@ void LightShading::run(const RenderingContext& ctx, RenderPassWorkContext& rgrap
 
 	// Forward shading last
 	{
+		cmdb.pushDebugMarker("ForwardShading", Vec3(0.0f, 1.0f, 1.0f));
+
 		if(enableVrs)
 		{
 			cmdb.setVrsRate(VrsRate::k2x2);
@@ -228,6 +242,8 @@ void LightShading::run(const RenderingContext& ctx, RenderPassWorkContext& rgrap
 			// Restore
 			cmdb.setVrsRate(VrsRate::k1x1);
 		}
+
+		cmdb.popDebugMarker();
 	}
 }
 

+ 23 - 7
AnKi/Renderer/Utils/GpuVisibility.cpp

@@ -1222,6 +1222,12 @@ void GpuVisibilityLocalLights::populateRenderGraph(GpuVisibilityLocalLightsInput
 	const BufferView lightIndexOffsetsPerCellBuff = allocateStructuredBuffer<U32>(cellCount);
 	const BufferView lightIndexCountBuff = allocateStructuredBuffer<U32>(1);
 	const BufferView lightIndexListBuff = allocateStructuredBuffer<U32>(in.m_lightIndexListSize);
+	const BufferView threadgroupCountBuff = allocateStructuredBuffer<U32>(1);
+
+	constexpr U32 kPrefixSumThreadCount = 1024; // Common for most GPUs
+	constexpr U32 kPrefixSumElementCountPerThreadgroup = kPrefixSumThreadCount * 2;
+	const BufferView groupWidePrefixSumsBuff =
+		allocateStructuredBuffer<U32>((cellCount + kPrefixSumElementCountPerThreadgroup - 1) / kPrefixSumElementCountPerThreadgroup);
 
 	const BufferHandle dep = rgraph.importBuffer(lightIndexCountBuff, BufferUsageBit::kNone);
 
@@ -1235,7 +1241,8 @@ void GpuVisibilityLocalLights::populateRenderGraph(GpuVisibilityLocalLightsInput
 	consts.m_maxLightIndices = in.m_lightIndexListSize;
 	consts.m_gridVolumeMin = out.m_lightGridMin;
 	consts.m_gridVolumeSize = gridSize;
-	consts.m_cellCounts = Vec3(in.m_cellCounts);
+	consts.m_cellCounts = in.m_cellCounts;
+	consts.m_cellCount = cellCount;
 
 	// Setup
 	{
@@ -1243,7 +1250,8 @@ void GpuVisibilityLocalLights::populateRenderGraph(GpuVisibilityLocalLightsInput
 
 		pass.newBufferDependency(dep, BufferUsageBit::kUavCompute);
 
-		pass.setWork([this, lightIndexCountsPerCellBuff, lightIndexCountBuff, cellCount](RenderPassWorkContext& rgraph) {
+		pass.setWork([this, lightIndexCountsPerCellBuff, lightIndexCountBuff, cellCount, threadgroupCountBuff,
+					  groupWidePrefixSumsBuff](RenderPassWorkContext& rgraph) {
 			ANKI_TRACE_SCOPED_EVENT(GpuVisibilityLocalLightsSetup);
 			CommandBuffer& cmdb = *rgraph.m_commandBuffer;
 
@@ -1251,6 +1259,8 @@ void GpuVisibilityLocalLights::populateRenderGraph(GpuVisibilityLocalLightsInput
 
 			cmdb.bindUav(0, 0, lightIndexCountsPerCellBuff);
 			cmdb.bindUav(1, 0, lightIndexCountBuff);
+			cmdb.bindUav(2, 0, groupWidePrefixSumsBuff);
+			cmdb.bindUav(3, 0, threadgroupCountBuff);
 
 			dispatchPPCompute(cmdb, 64, 1, cellCount, 1);
 		});
@@ -1265,7 +1275,8 @@ void GpuVisibilityLocalLights::populateRenderGraph(GpuVisibilityLocalLightsInput
 		pass.newBufferDependency(dep, BufferUsageBit::kUavCompute);
 		pass.newBufferDependency(getRenderer().getGpuSceneBufferHandle(), BufferUsageBit::kSrvCompute);
 
-		pass.setWork([this, lightIndexCountsPerCellBuff, lightIndexCountBuff, consts](RenderPassWorkContext& rgraph) {
+		pass.setWork([this, lightIndexCountsPerCellBuff, lightIndexCountBuff, consts, threadgroupCountBuff,
+					  groupWidePrefixSumsBuff](RenderPassWorkContext& rgraph) {
 			ANKI_TRACE_SCOPED_EVENT(GpuVisibilityLocalLightsCount);
 
 			const GpuSceneArrays::Light& lights = GpuSceneArrays::Light::getSingleton();
@@ -1278,10 +1289,12 @@ void GpuVisibilityLocalLights::populateRenderGraph(GpuVisibilityLocalLightsInput
 
 			cmdb.bindUav(0, 0, lightIndexCountsPerCellBuff);
 			cmdb.bindUav(1, 0, lightIndexCountBuff);
+			cmdb.bindUav(2, 0, groupWidePrefixSumsBuff);
+			cmdb.bindUav(3, 0, threadgroupCountBuff);
 
 			cmdb.setFastConstants(&consts, sizeof(consts));
 
-			dispatchPPCompute(cmdb, 64, 1, lights.getElementCount(), 1);
+			dispatchPPCompute(cmdb, 64, 1, consts.m_cellCount, 1);
 		});
 	}
 
@@ -1292,19 +1305,22 @@ void GpuVisibilityLocalLights::populateRenderGraph(GpuVisibilityLocalLightsInput
 
 		pass.newBufferDependency(dep, BufferUsageBit::kUavCompute);
 
-		pass.setWork([this, lightIndexCountsPerCellBuff, lightIndexOffsetsPerCellBuff, lightIndexCountBuff, consts](RenderPassWorkContext& rgraph) {
+		pass.setWork([this, lightIndexCountsPerCellBuff, lightIndexOffsetsPerCellBuff, lightIndexCountBuff, consts,
+					  groupWidePrefixSumsBuff](RenderPassWorkContext& rgraph) {
 			ANKI_TRACE_SCOPED_EVENT(GpuVisibilityLocalLightsPrefixSum);
 			CommandBuffer& cmdb = *rgraph.m_commandBuffer;
 
 			cmdb.bindShaderProgram(m_prefixSumGrProg.get());
 
+			cmdb.bindSrv(0, 0, groupWidePrefixSumsBuff);
+
 			cmdb.bindUav(0, 0, lightIndexCountsPerCellBuff);
 			cmdb.bindUav(1, 0, lightIndexOffsetsPerCellBuff);
 			cmdb.bindUav(2, 0, lightIndexCountBuff);
 
 			cmdb.setFastConstants(&consts, sizeof(consts));
 
-			cmdb.dispatchCompute(1, 1, 1);
+			cmdb.dispatchCompute((consts.m_cellCount + kPrefixSumElementCountPerThreadgroup - 1) / kPrefixSumElementCountPerThreadgroup, 1, 1);
 		});
 	}
 
@@ -1333,7 +1349,7 @@ void GpuVisibilityLocalLights::populateRenderGraph(GpuVisibilityLocalLightsInput
 
 			cmdb.setFastConstants(&consts, sizeof(consts));
 
-			dispatchPPCompute(cmdb, 64, 1, lights.getElementCount(), 1);
+			dispatchPPCompute(cmdb, 64, 1, consts.m_cellCount, 1);
 		});
 	}
 }

+ 6 - 1
AnKi/ShaderCompiler/Dxc.cpp

@@ -256,11 +256,16 @@ static Error compileHlsl(CString src, ShaderType shaderType, Bool compileWith16b
 
 	CComPtr<IDxcBlob> pShader = nullptr;
 	ANKI_DXC_CHECK(pResults->GetOutput(DXC_OUT_OBJECT, IID_PPV_ARGS(&pShader), nullptr));
-	if(pShader != nullptr)
+	if(pShader != nullptr && pShader->GetBufferSize() > 0)
 	{
 		bin.resize(U32(pShader->GetBufferSize()));
 		memcpy(bin.getBegin(), pShader->GetBufferPointer(), pShader->GetBufferSize());
 	}
+	else
+	{
+		ANKI_SHADER_COMPILER_LOGE("DXC returned an empty binary blob");
+		return Error::kFunctionFailed;
+	}
 
 	return Error::kNone;
 }

+ 196 - 171
AnKi/Shaders/GpuVisibilityLocalLights.ankiprog

@@ -13,11 +13,15 @@
 #pragma anki technique PrefixSum comp
 #pragma anki technique Fill comp
 
-#include <AnKi/Shaders/Common.hlsl>
+#include <AnKi/Shaders/Functions.hlsl>
 #include <AnKi/Shaders/Include/GpuSceneTypes.h>
 #include <AnKi/Shaders/Include/GpuVisibilityTypes.h>
 #include <AnKi/Shaders/VisibilityAndCollisionFunctions.hlsl>
 
+constexpr U32 kPrefixSumThreadCount = 1024; // Common for most GPUs
+constexpr U32 kPrefixSumElementCountPerThreadgroup =
+	kPrefixSumThreadCount * 2; // How many elements a single threadgroup can calculate the prefix sum of
+
 Bool insideFrustum(Vec4 planes[5], Vec3 aabbMin, Vec3 aabbMax)
 {
 	[unroll] for(U32 i = 0; i < 5; ++i)
@@ -31,106 +35,93 @@ Bool insideFrustum(Vec4 planes[5], Vec3 aabbMin, Vec3 aabbMax)
 	return true;
 }
 
-template<typename TFunc>
-void lightVsCellVisibility(StructuredBuffer<GpuSceneLight> lights, U32 lightIdx, GpuVisibilityLocalLightsConsts consts,
-						   RWStructuredBuffer<U32> lightIndexCount, Bool detailedTests, TFunc binLightToCellFunc)
+template<typename TFunc, typename TFunc2>
+void lightVsCellVisibility(StructuredBuffer<GpuSceneLight> lights, U32 cellIdx, GpuVisibilityLocalLightsConsts consts,
+						   RWStructuredBuffer<U32> lightIndexCount, Bool detailedTests, TFunc binLightToCellFunc, TFunc2 informLightIndexCountFunc)
 {
-	const U32 lightCount = getStructuredBufferElementCount(lights);
-	if(lightIdx >= lightCount)
+	if(cellIdx >= consts.m_cellCount)
 	{
 		return;
 	}
 
-	const GpuSceneLight light = SBUFF(lights, lightIdx);
+	UVec3 cellId;
+	unflatten3dArrayIndex(consts.m_cellCounts.z, consts.m_cellCounts.y, consts.m_cellCounts.x, cellIdx, cellId.z, cellId.y, cellId.x);
+	const Vec3 cellMin = cellId * consts.m_cellSize + consts.m_gridVolumeMin;
+	const Vec3 cellMax = cellMin + consts.m_cellSize;
 
-	// Get the light bounds
-	Vec3 worldLightAabbMin;
-	Vec3 worldLightAabbMax;
-	if((U32)light.m_flags & (U32)GpuSceneLightFlag::kPointLight)
-	{
-		worldLightAabbMin = light.m_position - light.m_radius;
-		worldLightAabbMax = light.m_position + light.m_radius;
-	}
-	else
+	U32 visibleLightCount = 0;
+	const U32 lightCount = getStructuredBufferElementCount(lights);
+	for(U32 i = 0; i < lightCount; ++i)
 	{
-		worldLightAabbMin = light.m_position;
-		worldLightAabbMax = light.m_position;
+		const GpuSceneLight light = lights[i];
 
-		[unroll] for(U32 i = 0; i < 4; ++i)
+		// Get the light bounds
+		Vec3 worldLightAabbMin;
+		Vec3 worldLightAabbMax;
+		if((U32)light.m_flags & (U32)GpuSceneLightFlag::kPointLight)
 		{
-			worldLightAabbMin = min(worldLightAabbMin, light.m_edgePoints[i]);
-			worldLightAabbMax = max(worldLightAabbMax, light.m_edgePoints[i]);
+			worldLightAabbMin = light.m_position - light.m_radius;
+			worldLightAabbMax = light.m_position + light.m_radius;
 		}
-	}
+		else
+		{
+			worldLightAabbMin = light.m_position;
+			worldLightAabbMax = light.m_position;
 
-	Vec3 localLightAabbMin = worldLightAabbMin - consts.m_gridVolumeMin;
-	localLightAabbMin = clamp(localLightAabbMin, 0.0, consts.m_gridVolumeSize - kEpsilonF32);
+			[unroll] for(U32 i = 0; i < 4; ++i)
+			{
+				worldLightAabbMin = min(worldLightAabbMin, light.m_edgePoints[i]);
+				worldLightAabbMax = max(worldLightAabbMax, light.m_edgePoints[i]);
+			}
+		}
 
-	Vec3 localLightAabbMax = worldLightAabbMax - consts.m_gridVolumeMin;
-	localLightAabbMax = clamp(localLightAabbMax, 0.0, consts.m_gridVolumeSize - kEpsilonF32);
+		if(!aabbAabbOverlap(worldLightAabbMin, worldLightAabbMax, cellMin, cellMax))
+		{
+			continue;
+		}
 
-	if(any(localLightAabbMin == localLightAabbMax))
-	{
-		// Outside the volume
-		return;
-	}
+		if(detailedTests)
+		{
+			Vec4 spotLightPlanes[5];
+			if((U32)light.m_flags & (U32)GpuSceneLightFlag::kSpotLight)
+			{
+				const Vec3 pe = light.m_position;
+				const Vec3 p0 = light.m_edgePoints[0];
+				const Vec3 p1 = light.m_edgePoints[1];
+				const Vec3 p2 = light.m_edgePoints[2];
+				const Vec3 p3 = light.m_edgePoints[3];
+				spotLightPlanes[0] = computePlane(pe, p0, p3);
+				spotLightPlanes[1] = computePlane(pe, p1, p0);
+				spotLightPlanes[2] = computePlane(pe, p2, p1);
+				spotLightPlanes[3] = computePlane(pe, p3, p2);
+				spotLightPlanes[4] = computePlane(p3, p0, p1);
+			}
 
-	Vec4 spotLightPlanes[5];
-	if((U32)light.m_flags & (U32)GpuSceneLightFlag::kSpotLight)
-	{
-		const Vec3 pe = light.m_position;
-		const Vec3 p0 = light.m_edgePoints[0];
-		const Vec3 p1 = light.m_edgePoints[1];
-		const Vec3 p2 = light.m_edgePoints[2];
-		const Vec3 p3 = light.m_edgePoints[3];
-		spotLightPlanes[0] = computePlane(pe, p0, p3);
-		spotLightPlanes[1] = computePlane(pe, p1, p0);
-		spotLightPlanes[2] = computePlane(pe, p2, p1);
-		spotLightPlanes[3] = computePlane(pe, p3, p2);
-		spotLightPlanes[4] = computePlane(p3, p0, p1);
-	}
+			if((U32)light.m_flags & (U32)GpuSceneLightFlag::kPointLight && !aabbSphereOverlap(cellMin, cellMax, light.m_position, light.m_radius))
+			{
+				continue;
+			}
+			else if((U32)light.m_flags & (U32)GpuSceneLightFlag::kSpotLight && !insideFrustum(spotLightPlanes, cellMin, cellMax))
+			{
+				continue;
+			}
+		}
 
-	const Vec3 localLightFirstCell = floor(localLightAabbMin / consts.m_cellSize);
-	const Vec3 localLightEndCell = ceil(localLightAabbMax / consts.m_cellSize);
+		U32 count;
+		InterlockedAdd(SBUFF(lightIndexCount, 0), 1, count);
+		++count;
 
-	for(F32 x = localLightFirstCell.x; x < localLightEndCell.x; x += 1.0)
-	{
-		for(F32 y = localLightFirstCell.y; y < localLightEndCell.y; y += 1.0)
+		if(count > consts.m_maxLightIndices)
 		{
-			for(F32 z = localLightFirstCell.z; z < localLightEndCell.z; z += 1.0)
-			{
-				const Vec3 cellMin = Vec3(x, y, z) * consts.m_cellSize + consts.m_gridVolumeMin;
-				const Vec3 cellMax = cellMin + consts.m_cellSize;
-
-				if(detailedTests)
-				{
-					if((U32)light.m_flags & (U32)GpuSceneLightFlag::kPointLight
-					   && !aabbSphereOverlap(cellMin, cellMax, light.m_position, light.m_radius))
-					{
-						continue;
-					}
-					else if((U32)light.m_flags & (U32)GpuSceneLightFlag::kSpotLight && !insideFrustum(spotLightPlanes, cellMin, cellMax))
-					{
-						continue;
-					}
-				}
-
-				U32 count;
-				InterlockedAdd(SBUFF(lightIndexCount, 0), 1, count);
-				++count;
-
-				if(count > consts.m_maxLightIndices)
-				{
-					// Light index list is too small
-					break;
-				}
-
-				const F32 cellIdx = z * consts.m_cellCounts.y * consts.m_cellCounts.x + y * consts.m_cellCounts.x + x;
-
-				binLightToCellFunc(cellIdx, lightIdx);
-			}
+			// Light index list is too small
+			break;
 		}
+
+		++visibleLightCount;
+		binLightToCellFunc(cellIdx, i);
 	}
+
+	informLightIndexCountFunc(cellIdx, visibleLightCount);
 }
 
 // ===========================================================================
@@ -140,21 +131,26 @@ void lightVsCellVisibility(StructuredBuffer<GpuSceneLight> lights, U32 lightIdx,
 
 RWStructuredBuffer<U32> g_lightIndexCountsPerCell : register(u0);
 RWStructuredBuffer<U32> g_lightIndexCount : register(u1);
-
-ANKI_FAST_CONSTANTS(GpuVisibilityLocalLightsConsts, g_consts)
+RWStructuredBuffer<U32> g_groupWidePrefixSums : register(u2);
+RWStructuredBuffer<U32> g_threadgroupCount : register(u3);
 
 [numthreads(64, 1, 1)] void main(COMPUTE_ARGS)
 {
 	if(svDispatchThreadId.x == 0)
 	{
 		SBUFF(g_lightIndexCount, 0) = 0;
+		SBUFF(g_threadgroupCount, 0) = 0;
 	}
 
-	const U32 elementCount = getStructuredBufferElementCount(g_lightIndexCountsPerCell);
-	if(svDispatchThreadId.x < elementCount)
+	if(svDispatchThreadId.x < getStructuredBufferElementCount(g_lightIndexCountsPerCell))
 	{
 		SBUFF(g_lightIndexCountsPerCell, svDispatchThreadId.x) = 0;
 	}
+
+	if(svDispatchThreadId.x < getStructuredBufferElementCount(g_groupWidePrefixSums))
+	{
+		SBUFF(g_groupWidePrefixSums, svDispatchThreadId.x) = 0;
+	}
 }
 #endif
 
@@ -170,6 +166,8 @@ StructuredBuffer<GpuSceneLight> g_lights : register(t0);
 
 RWStructuredBuffer<U32> g_lightIndexCountsPerCell : register(u0);
 RWStructuredBuffer<U32> g_lightIndexCount : register(u1);
+RWStructuredBuffer<U32> g_groupWidePrefixSums : register(u2);
+RWStructuredBuffer<U32> g_threadgroupCount : register(u3);
 
 ANKI_FAST_CONSTANTS(GpuVisibilityLocalLightsConsts, g_consts)
 
@@ -181,10 +179,50 @@ struct Func
 	}
 };
 
-[numthreads(64, 1, 1)] void main(COMPUTE_ARGS)
+struct Func2
+{
+	void operator()(U32 cellIdx, U32 visibleLightCount)
+	{
+		if(visibleLightCount)
+		{
+			const U32 group = cellIdx / kPrefixSumElementCountPerThreadgroup;
+			InterlockedAdd(SBUFF(g_groupWidePrefixSums, group), visibleLightCount);
+		}
+	}
+};
+
+constexpr U32 kThreadCount = 64;
+
+[numthreads(kThreadCount, 1, 1)] void main(COMPUTE_ARGS)
 {
 	Func func;
-	lightVsCellVisibility(g_lights, svDispatchThreadId.x, g_consts, g_lightIndexCount, false, func);
+	Func2 func2;
+	lightVsCellVisibility(g_lights, svDispatchThreadId.x, g_consts, g_lightIndexCount, false, func, func2);
+
+	// Sync to make sure all the atomic ops have finished before the following code reads them
+	AllMemoryBarrierWithGroupSync();
+
+	// Compute the group prefix sum
+	if(svGroupIndex == 0)
+	{
+		U32 threadgroupIdx;
+		InterlockedAdd(SBUFF(g_threadgroupCount, 0), 1, threadgroupIdx);
+		const U32 threadgroupCount = (g_consts.m_cellCount + kThreadCount - 1) / kThreadCount;
+		const Bool lastThreadgroupExecuting = (threadgroupIdx + 1 == threadgroupCount);
+
+		if(lastThreadgroupExecuting)
+		{
+			const U32 prefixSumGroupCount = getStructuredBufferElementCount(g_groupWidePrefixSums);
+
+			U32 count = 0;
+			for(U32 i = 0; i < prefixSumGroupCount; ++i)
+			{
+				const U32 c = SBUFF(g_groupWidePrefixSums, i);
+				SBUFF(g_groupWidePrefixSums, i) = count;
+				count += c;
+			}
+		}
+	}
 }
 #endif
 
@@ -197,8 +235,7 @@ struct Func
 
 #if NOT_ZERO(ANKI_TECHNIQUE_PrefixSum)
 
-constexpr U32 kThreadCount = 1024; // Common for most GPUs
-constexpr U32 kMaxElementCountPerIteration = kThreadCount * 2;
+StructuredBuffer<U32> g_groupWidePrefixSums : register(t0);
 
 RWStructuredBuffer<U32> g_inputElements : register(u0); // It's the g_lightIndexCountsPerCell. RW because we want to zero it at the end
 
@@ -209,115 +246,95 @@ RWStructuredBuffer<U32> g_lightIndexCount : register(u2);
 
 ANKI_FAST_CONSTANTS(GpuVisibilityLocalLightsConsts, g_consts)
 
-groupshared U32 g_tmp[kMaxElementCountPerIteration];
-groupshared U32 g_valueSum;
+groupshared U32 g_tmp[kPrefixSumElementCountPerThreadgroup];
 
-[numthreads(kThreadCount, 1, 1)] void main(COMPUTE_ARGS)
+[numthreads(kPrefixSumThreadCount, 1, 1)] void main(COMPUTE_ARGS)
 {
-	const U32 elementCount = g_consts.m_cellCounts.x * g_consts.m_cellCounts.y * g_consts.m_cellCounts.z;
-	const U32 iterationCount = (elementCount + kMaxElementCountPerIteration - 1) / kMaxElementCountPerIteration;
+	const U32 elementCount = g_consts.m_cellCount;
 
 	const U32 tid = svGroupIndex;
+	const U32 group = svGroupId.x;
 
-	g_valueSum = 0; // No need for barrier, there are plenty bellow
+	const U32 firstElement = group * kPrefixSumElementCountPerThreadgroup;
+	const U32 endElement = min((group + 1) * kPrefixSumElementCountPerThreadgroup, elementCount);
 
-	for(U32 it = 0; it < iterationCount; ++it)
-	{
-		GroupMemoryBarrierWithGroupSync(); // Barrier because of the loop
+	// Load input into shared memory
+	const U32 inIdx1 = 2 * tid + firstElement;
+	const U32 value1 = (inIdx1 < endElement) ? SBUFF(g_inputElements, inIdx1) : 0;
+	g_tmp[2 * tid] = value1;
 
-		const U32 firstElement = it * kMaxElementCountPerIteration;
-		const U32 endElement = min((it + 1) * kMaxElementCountPerIteration, elementCount);
+	const U32 inIdx2 = 2 * tid + 1 + firstElement;
+	const U32 value2 = (inIdx2 < endElement) ? SBUFF(g_inputElements, inIdx2) : 0;
+	g_tmp[2 * tid + 1] = value2;
 
-		// Load input into shared memory
-		const U32 inIdx1 = 2 * tid + firstElement;
-		const U32 value1 = (inIdx1 < endElement) ? SBUFF(g_inputElements, inIdx1) : 0;
-		g_tmp[2 * tid] = value1;
-
-		const U32 inIdx2 = 2 * tid + 1 + firstElement;
-		const U32 value2 = (inIdx2 < endElement) ? SBUFF(g_inputElements, inIdx2) : 0;
-		g_tmp[2 * tid + 1] = value2;
-
-		// Perform reduction
-		U32 offset = 1;
-		for(U32 d = kMaxElementCountPerIteration >> 1; d > 0; d >>= 1)
-		{
-			GroupMemoryBarrierWithGroupSync();
+	// The g_inputElements have already been read, so reset them to zero so they can be reused by the next job
+	if(inIdx1 < endElement)
+	{
+		SBUFF(g_inputElements, inIdx1) = 0;
+	}
 
-			if(tid < d)
-			{
-				const U32 ai = offset * (2 * tid + 1) - 1;
-				const U32 bi = offset * (2 * tid + 2) - 1;
-				g_tmp[bi] += g_tmp[ai];
-			}
+	if(inIdx2 < endElement)
+	{
+		SBUFF(g_inputElements, inIdx2) = 0;
+	}
 
-			offset *= 2;
-		}
+	// Perform reduction
+	U32 offset = 1;
+	for(U32 d = kPrefixSumElementCountPerThreadgroup >> 1; d > 0; d >>= 1)
+	{
+		GroupMemoryBarrierWithGroupSync();
 
-		// Clear the last element
-		if(tid == 0)
+		if(tid < d)
 		{
-			g_tmp[kMaxElementCountPerIteration - 1] = 0;
+			const U32 ai = offset * (2 * tid + 1) - 1;
+			const U32 bi = offset * (2 * tid + 2) - 1;
+			g_tmp[bi] += g_tmp[ai];
 		}
 
-		// Perform downsweep and build scan
-		for(U32 d = 1; d < kMaxElementCountPerIteration; d *= 2)
-		{
-			offset >>= 1;
-
-			GroupMemoryBarrierWithGroupSync();
+		offset *= 2;
+	}
 
-			if(tid < d)
-			{
-				const U32 ai = offset * (2 * tid + 1) - 1;
-				const U32 bi = offset * (2 * tid + 2) - 1;
-				const U32 t = g_tmp[ai];
-				g_tmp[ai] = g_tmp[bi];
-				g_tmp[bi] += t;
-			}
-		}
+	// Clear the last element
+	if(tid == 0)
+	{
+		g_tmp[kPrefixSumElementCountPerThreadgroup - 1] = 0;
+	}
 
-		// Good time to read it
-		const U32 valueSum = g_valueSum;
+	// Perform downsweep and build scan
+	for(U32 d = 1; d < kPrefixSumElementCountPerThreadgroup; d *= 2)
+	{
+		offset >>= 1;
 
 		GroupMemoryBarrierWithGroupSync();
 
-		// Write to output buffer
-		if(inIdx1 < endElement)
+		if(tid < d)
 		{
-			SBUFF(g_outputElements, inIdx1) = g_tmp[2 * tid] + valueSum;
+			const U32 ai = offset * (2 * tid + 1) - 1;
+			const U32 bi = offset * (2 * tid + 2) - 1;
+			const U32 t = g_tmp[ai];
+			g_tmp[ai] = g_tmp[bi];
+			g_tmp[bi] += t;
 		}
+	}
 
-		if(inIdx2 < endElement)
-		{
-			SBUFF(g_outputElements, inIdx2) = g_tmp[2 * tid + 1] + valueSum;
-		}
+	GroupMemoryBarrierWithGroupSync();
 
-		// Good time to update it
-		if(value1 + value2 > 0)
-		{
-			InterlockedAdd(g_valueSum, value1 + value2);
-		}
+	// Write to output buffer
+	const U32 groupPrefixSum = SBUFF(g_groupWidePrefixSums, group);
+	if(inIdx1 < endElement)
+	{
+		SBUFF(g_outputElements, inIdx1) = g_tmp[2 * tid] + groupPrefixSum;
 	}
 
-	// Abuse this compute job to also reset some buffers
-	if(tid == 0)
+	if(inIdx2 < endElement)
 	{
-		SBUFF(g_lightIndexCount, 0) = 0;
+		SBUFF(g_outputElements, inIdx2) = g_tmp[2 * tid + 1] + groupPrefixSum;
 	}
 
+	// Abuse this compute job to also reset g_lightIndexCount
+	if(svDispatchThreadId.x == 0)
 	{
-		const U32 elementsPerThread = (elementCount + kThreadCount - 1) / kThreadCount;
-
-		for(U32 i = 0; i < elementsPerThread; ++i)
-		{
-			const U32 idx = tid * elementsPerThread + i;
-			if(idx >= elementCount)
-			{
-				break;
-			}
-
-			SBUFF(g_inputElements, idx) = 0;
-		}
+		SBUFF(g_lightIndexCount, 0) = 0;
 	}
 }
 #endif
@@ -352,10 +369,18 @@ struct Func
 	}
 };
 
+struct Func2
+{
+	void operator()(U32 clusterIdx, U32 visibleLightCount)
+	{
+	}
+};
+
 [numthreads(64, 1, 1)] void main(COMPUTE_ARGS)
 {
 	Func func;
-	lightVsCellVisibility(g_lights, svDispatchThreadId.x, g_consts, g_lightIndexCount, true, func);
+	Func2 func2;
+	lightVsCellVisibility(g_lights, svDispatchThreadId.x, g_consts, g_lightIndexCount, true, func, func2);
 }
 
 #endif

+ 2 - 2
AnKi/Shaders/Include/GpuVisibilityTypes.h

@@ -115,8 +115,8 @@ struct GpuVisibilityLocalLightsConsts
 	Vec3 m_gridVolumeSize;
 	F32 m_padding3;
 
-	Vec3 m_cellCounts;
-	F32 m_padding4;
+	UVec3 m_cellCounts;
+	U32 m_cellCount;
 };
 
 ANKI_END_NAMESPACE

+ 1 - 1
AnKi/Shaders/VisibilityAndCollisionFunctions.hlsl

@@ -52,7 +52,7 @@ Bool testRayTriangle(Vec3 rayOrigin, Vec3 rayDir, Vec3 v0, Vec3 v1, Vec3 v2, Boo
 }
 
 /// Return true if two AABBs overlap.
-Bool testAabbAabb(Vec3 aMin, Vec3 aMax, Vec3 bMin, Vec3 bMax)
+Bool aabbAabbOverlap(Vec3 aMin, Vec3 aMax, Vec3 bMin, Vec3 bMax)
 {
 	return all(aMin < bMax) && all(bMin < aMax);
 }