Browse Source

Add probes in the indirect diffuse

Panagiotis Christopoulos Charitos 4 years ago
parent
commit
3bec6bf425

+ 1 - 0
AnKi/Renderer/ConfigDefs.h

@@ -21,6 +21,7 @@ ANKI_CONFIG_OPTION(r_ssrDepthLod, 2, 0, 1000)
 
 ANKI_CONFIG_OPTION(r_ssgiMaxSteps, 32, 1, 2048)
 ANKI_CONFIG_OPTION(r_ssgiDepthLod, 2, 0, 1000)
+ANKI_CONFIG_OPTION(r_ssgiStepIncrement, 32, 1, 512)
 
 ANKI_CONFIG_OPTION(r_shadowMappingTileResolution, 128, 16, 2048)
 ANKI_CONFIG_OPTION(r_shadowMappingTileCountPerRowOrColumn, 16, 1, 256)

+ 26 - 25
AnKi/Renderer/IndirectDiffuse.cpp

@@ -9,8 +9,9 @@
 #include <AnKi/Renderer/GBuffer.h>
 #include <AnKi/Renderer/DownscaleBlur.h>
 #include <AnKi/Renderer/MotionVectors.h>
+#include <AnKi/Renderer/GlobalIllumination.h>
 #include <AnKi/Core/ConfigSet.h>
-#include <AnKi/Shaders/Include/SsgiTypes.h>
+#include <AnKi/Shaders/Include/IndirectDiffuseTypes.h>
 
 namespace anki
 {
@@ -36,14 +37,14 @@ Error IndirectDiffuse::initInternal(const ConfigSet& cfg)
 
 	m_main.m_maxSteps = cfg.getNumberU32("r_ssgiMaxSteps");
 	m_main.m_depthLod = min(cfg.getNumberU32("r_ssgiDepthLod"), m_r->getDepthDownscale().getMipmapCount() - 1);
-	m_main.m_firstStepPixels = 32;
+	m_main.m_stepIncrement = cfg.getNumberU32("r_ssgiStepIncrement");
 
 	ANKI_CHECK(getResourceManager().loadResource("EngineAssets/BlueNoise_Rgba8_16x16.png", m_main.m_noiseImage));
 
-	// Init SSGI
+	// Init SSGI+probes pass
 	{
-		m_main.m_rtDescr =
-			m_r->create2DRenderTargetDescription(size.x(), size.y(), Format::R16G16B16A16_SFLOAT, "IndirectDiffuse");
+		m_main.m_rtDescr = m_r->create2DRenderTargetDescription(size.x(), size.y(), Format::B10G11R11_UFLOAT_PACK32,
+																"IndirectDiffuse");
 		m_main.m_rtDescr.bake();
 
 		ANKI_CHECK(getResourceManager().loadResource("Shaders/IndirectDiffuse.ankiprog", m_main.m_prog));
@@ -80,30 +81,30 @@ void IndirectDiffuse::populateRenderGraph(RenderingContext& ctx)
 			CommandBufferPtr& cmdb = rgraphCtx.m_commandBuffer;
 			cmdb->bindShaderProgram(m_main.m_grProg);
 
-			rgraphCtx.bindImage(0, 0, m_runCtx.m_ssgiRtHandle, TextureSubresourceInfo());
+			const ClusteredShadingContext& binning = ctx.m_clusteredShading;
+			bindUniforms(cmdb, 0, 0, binning.m_clusteredShadingUniformsToken);
+			m_r->getGlobalIllumination().bindVolumeTextures(ctx, rgraphCtx, 0, 1);
+			bindUniforms(cmdb, 0, 2, binning.m_globalIlluminationProbesToken);
+			bindStorage(cmdb, 0, 3, binning.m_clustersToken);
 
-			// Bind uniforms
-			SsgiUniforms* unis = allocateAndBindUniforms<SsgiUniforms*>(sizeof(SsgiUniforms), cmdb, 0, 1);
-			unis->m_depthBufferSize =
-				UVec2(m_r->getInternalResolution().x(), m_r->getInternalResolution().y()) >> (m_main.m_depthLod + 1);
-			unis->m_framebufferSize = m_r->getInternalResolution() / 2u;
-			unis->m_invProjMat = ctx.m_matrices.m_projectionJitter.getInverse();
-			unis->m_projMat = ctx.m_matrices.m_projectionJitter;
-			unis->m_prevViewProjMatMulInvViewProjMat =
-				ctx.m_prevMatrices.m_viewProjection * ctx.m_matrices.m_viewProjectionJitter.getInverse();
-			unis->m_normalMat = Mat3x4(Vec3(0.0f), ctx.m_matrices.m_view.getRotationPart());
-			unis->m_frameCount = m_r->getFrameCount() & MAX_U32;
-			unis->m_maxSteps = m_main.m_maxSteps;
-			unis->m_firstStepPixels = m_main.m_firstStepPixels;
-
-			// Bind the rest
-			cmdb->bindSampler(0, 2, m_r->getSamplers().m_trilinearClamp);
-			rgraphCtx.bindColorTexture(0, 3, m_r->getGBuffer().getColorRt(2));
+			rgraphCtx.bindImage(0, 4, m_runCtx.m_ssgiRtHandle, TextureSubresourceInfo());
+
+			cmdb->bindSampler(0, 5, m_r->getSamplers().m_trilinearClamp);
+			rgraphCtx.bindColorTexture(0, 6, m_r->getGBuffer().getColorRt(2));
 
 			TextureSubresourceInfo hizSubresource;
 			hizSubresource.m_firstMipmap = m_main.m_depthLod;
-			rgraphCtx.bindTexture(0, 4, m_r->getDepthDownscale().getHiZRt(), hizSubresource);
-			rgraphCtx.bindColorTexture(0, 5, m_r->getDownscaleBlur().getRt());
+			rgraphCtx.bindTexture(0, 7, m_r->getDepthDownscale().getHiZRt(), hizSubresource);
+			rgraphCtx.bindColorTexture(0, 8, m_r->getDownscaleBlur().getRt());
+
+			// Bind uniforms
+			IndirectDiffuseUniforms unis;
+			unis.m_depthBufferSize = m_r->getInternalResolution() >> (m_main.m_depthLod + 1);
+			unis.m_maxSteps = m_main.m_maxSteps;
+			unis.m_stepIncrement = m_main.m_stepIncrement;
+			unis.m_viewportSize = m_r->getInternalResolution() / 2u;
+			unis.m_viewportSizef = Vec2(unis.m_viewportSize);
+			cmdb->setPushConstants(&unis, sizeof(unis));
 
 			// Dispatch
 			dispatchPPCompute(cmdb, 8, 8, m_r->getInternalResolution().x() / 2, m_r->getInternalResolution().y() / 2);

+ 1 - 3
AnKi/Renderer/IndirectDiffuse.h

@@ -47,7 +47,7 @@ private:
 		RenderTargetDescription m_rtDescr;
 		ImageResourcePtr m_noiseImage;
 		U32 m_maxSteps = 32;
-		U32 m_firstStepPixels = 16;
+		U32 m_stepIncrement = 16;
 		U32 m_depthLod = 0;
 	} m_main;
 
@@ -58,8 +58,6 @@ private:
 	} m_runCtx;
 
 	ANKI_USE_RESULT Error initInternal(const ConfigSet& cfg);
-
-	void run(const RenderingContext& ctx, RenderPassWorkContext& rgraphCtx);
 };
 /// @}
 

+ 2 - 7
AnKi/Renderer/MotionVectors.cpp

@@ -78,13 +78,8 @@ void MotionVectors::run(const RenderingContext& ctx, RenderPassWorkContext& rgra
 		Mat4 m_prevViewProjectionInvMat;
 	} pc;
 
-	// This reprojection is not enterely correct. It first unprojects the current depth to world space, all fine here.
-	// Then it projects it to the previous frame but assumes identity jitter. Then it jitters using current frame's
-	// jitter. The last multiplication I don't get but it works perfectly for sampling the TAA history buffer.
-	pc.m_reprojectionMat = ctx.m_matrices.m_jitter * ctx.m_prevMatrices.m_viewProjection
-						   * ctx.m_matrices.m_viewProjectionJitter.getInverse();
-
-	pc.m_prevViewProjectionInvMat = ctx.m_prevMatrices.m_viewProjectionJitter.getInverse();
+	pc.m_reprojectionMat = ctx.m_matrices.m_reprojection;
+	pc.m_prevViewProjectionInvMat = ctx.m_prevMatrices.m_invertedProjectionJitter;
 	cmdb->setPushConstants(&pc, sizeof(pc));
 
 	dispatchPPCompute(cmdb, 8, 8, m_r->getInternalResolution().x(), m_r->getInternalResolution().y());

+ 6 - 2
AnKi/Renderer/Renderer.cpp

@@ -301,10 +301,13 @@ void Renderer::initJitteredMats()
 
 Error Renderer::populateRenderGraph(RenderingContext& ctx)
 {
+	ctx.m_prevMatrices = m_prevMatrices;
+
 	ctx.m_matrices.m_cameraTransform = ctx.m_renderQueue->m_cameraTransform;
 	ctx.m_matrices.m_view = ctx.m_renderQueue->m_viewMatrix;
 	ctx.m_matrices.m_projection = ctx.m_renderQueue->m_projectionMatrix;
 	ctx.m_matrices.m_viewProjection = ctx.m_renderQueue->m_viewProjectionMatrix;
+	ctx.m_matrices.m_viewRotation = ctx.m_renderQueue->m_viewMatrix.getRotationPart();
 
 	ctx.m_matrices.m_jitter = m_jitteredMats8x[m_frameCount & (m_jitteredMats8x.getSize() - 1)];
 	ctx.m_matrices.m_projectionJitter = ctx.m_matrices.m_jitter * ctx.m_matrices.m_projection;
@@ -314,9 +317,10 @@ Error Renderer::populateRenderGraph(RenderingContext& ctx)
 	ctx.m_matrices.m_invertedProjectionJitter = ctx.m_matrices.m_projectionJitter.getInverse();
 	ctx.m_matrices.m_invertedView = ctx.m_matrices.m_view.getInverse();
 
-	ctx.m_matrices.m_unprojectionParameters = ctx.m_matrices.m_projection.extractPerspectiveUnprojectionParams();
+	ctx.m_matrices.m_reprojection =
+		ctx.m_matrices.m_jitter * ctx.m_prevMatrices.m_viewProjection * ctx.m_matrices.m_invertedViewProjectionJitter;
 
-	ctx.m_prevMatrices = m_prevMatrices;
+	ctx.m_matrices.m_unprojectionParameters = ctx.m_matrices.m_projection.extractPerspectiveUnprojectionParams();
 
 	// Check if resources got loaded
 	if(m_prevLoadRequestCount != m_resources->getLoadingRequestCount()

+ 1 - 0
AnKi/ShaderCompiler/ShaderProgramParser.cpp

@@ -173,6 +173,7 @@ static const char* SHADER_HEADER = R"(#version 460 core
 #endif
 
 #define Mat3 mat3
+#define _ANKI_SIZEOF_mat3 36u
 
 #define Mat4 mat4
 #define _ANKI_SIZEOF_mat4 64u

+ 11 - 1
AnKi/Shaders/Include/ClusteredShadingTypes.h

@@ -163,6 +163,9 @@ struct CommonMatrices
 	Mat4 m_view ANKI_CPP_CODE(= Mat4::getIdentity());
 	Mat4 m_projection ANKI_CPP_CODE(= Mat4::getIdentity());
 	Mat4 m_viewProjection ANKI_CPP_CODE(= Mat4::getIdentity());
+	Mat3 m_viewRotation ANKI_CPP_CODE(= Mat3::getIdentity());
+
+	F32 m_padding[3u]; // Because of the alignment requirements of some of the following members (in C++)
 
 	Mat4 m_jitter ANKI_CPP_CODE(= Mat4::getIdentity());
 	Mat4 m_projectionJitter ANKI_CPP_CODE(= Mat4::getIdentity());
@@ -173,9 +176,16 @@ struct CommonMatrices
 	Mat4 m_invertedProjectionJitter ANKI_CPP_CODE(= Mat4::getIdentity()); ///< To unproject in view space.
 	Mat4 m_invertedView ANKI_CPP_CODE(= Mat4::getIdentity());
 
+	/// It's being used to reproject a clip space position of the current frame to the previous frame. Its value should
+	/// be m_jitter * m_prevFrame.m_viewProjection * m_invertedViewProjectionJitter. At first it unprojects the current
+	/// position to world space, all fine here. Then it projects to the previous frame as if the previous frame was
+	/// using the current frame's jitter matrix.
+	Mat4 m_reprojection ANKI_CPP_CODE(= Mat4::getIdentity());
+
 	Vec4 m_unprojectionParameters ANKI_CPP_CODE(= Vec4(0.0f)); ///< To unproject to view space. Jitter not considered.
 };
-const U32 _ANKI_SIZEOF_CommonMatrices = 11u * ANKI_SIZEOF(Mat4) + ANKI_SIZEOF(Vec4);
+const U32 _ANKI_SIZEOF_CommonMatrices =
+	12u * ANKI_SIZEOF(Mat4) + ANKI_SIZEOF(Vec4) + ANKI_SIZEOF(Mat3) + ANKI_SIZEOF(F32) * 3u;
 ANKI_SHADER_STATIC_ASSERT(sizeof(CommonMatrices) == _ANKI_SIZEOF_CommonMatrices);
 
 /// Common uniforms for light shading passes.

+ 78 - 23
AnKi/Shaders/IndirectDiffuse.ankiprog

@@ -3,7 +3,7 @@
 // Code licensed under the BSD License.
 // http://www.anki3d.org/LICENSE
 
-// Dose SSGI and GI probe sampling
+// Does SSGI and GI probe sampling
 
 #pragma anki start comp
 #include <AnKi/Shaders/SsRaymarching.glsl>
@@ -11,42 +11,49 @@
 #include <AnKi/Shaders/PackFunctions.glsl>
 #include <AnKi/Shaders/ImportanceSampling.glsl>
 #include <AnKi/Shaders/TonemappingFunctions.glsl>
-#include <AnKi/Shaders/Include/SsgiTypes.h>
+#include <AnKi/Shaders/Include/IndirectDiffuseTypes.h>
 
 const UVec2 WORKGROUP_SIZE = UVec2(8, 8);
 layout(local_size_x = WORKGROUP_SIZE.x, local_size_y = WORKGROUP_SIZE.y) in;
 
-layout(set = 0, binding = 0, rgba16f) uniform image2D out_img;
+#define CLUSTERED_SHADING_SET 0
+#define CLUSTERED_SHADING_UNIFORMS_BINDING 0
+#define CLUSTERED_SHADING_GI_BINDING 1
+#define CLUSTERED_SHADING_CLUSTERS_BINDING 3
+#include <AnKi/Shaders/ClusteredShadingCommon.glsl>
 
-layout(set = 0, binding = 1, row_major, std140) uniform b_unis
+layout(set = 0, binding = 4) writeonly uniform image2D out_img;
+
+layout(set = 0, binding = 5) uniform sampler u_trilinearClampSampler;
+layout(set = 0, binding = 6) uniform texture2D u_gbufferRt2;
+layout(set = 0, binding = 7) uniform texture2D u_depthRt;
+layout(set = 0, binding = 8) uniform texture2D u_lightBufferRt;
+
+layout(push_constant, std430) uniform b_pc
 {
-	SsgiUniforms u_unis;
+	IndirectDiffuseUniforms u_unis;
 };
 
-layout(set = 0, binding = 2) uniform sampler u_trilinearClampSampler;
-layout(set = 0, binding = 3) uniform texture2D u_gbufferRt2;
-layout(set = 0, binding = 4) uniform texture2D u_depthRt;
-layout(set = 0, binding = 5) uniform texture2D u_lightBufferRt;
-
 void main()
 {
-	const UVec2 fixedGlobalInvocationId = min(gl_GlobalInvocationID.xy, u_unis.m_framebufferSize);
-	const Vec2 uv = (Vec2(fixedGlobalInvocationId.xy) + 0.5) / Vec2(u_unis.m_framebufferSize);
+	const UVec2 fixedGlobalInvocationId = min(gl_GlobalInvocationID.xy, u_unis.m_viewportSize);
+	const Vec2 fragCoord = Vec2(fixedGlobalInvocationId.xy) + 0.5;
+	const Vec2 uv = fragCoord / u_unis.m_viewportSizef;
 	const Vec2 ndc = UV_TO_NDC(uv);
 
 	// Get normal
 	const Vec3 worldNormal = readNormalFromGBuffer(u_gbufferRt2, u_trilinearClampSampler, uv);
-	const Vec3 viewNormal = u_unis.m_normalMat * worldNormal;
+	const Vec3 viewNormal = u_clusteredShading.m_matrices.m_viewRotation * worldNormal;
 
 	// Get depth
 	const F32 depth = textureLod(u_depthRt, u_trilinearClampSampler, uv, 0.0).r;
 
 	// Compute view pos
-	const Vec4 viewPos4 = u_unis.m_invProjMat * Vec4(ndc, depth, 1.0);
+	const Vec4 viewPos4 = u_clusteredShading.m_matrices.m_invertedProjectionJitter * Vec4(ndc, depth, 1.0);
 	const Vec3 viewPos = viewPos4.xyz / viewPos4.w;
 
 	// Get a random point inside the hemisphere. Use hemisphereSampleCos to avoid perpendicular vecs to viewNormal
-	const UVec2 random = rand3DPCG16(UVec3(fixedGlobalInvocationId, u_unis.m_frameCount)).xy;
+	const UVec2 random = rand3DPCG16(UVec3(fixedGlobalInvocationId, u_clusteredShading.m_frame)).xy;
 	Vec2 randomCircle = hammersleyRandom16(0u, 0xFFFFu, random);
 	randomCircle.x *= 0.9; // Reduce the cone angle a bit to avoid self-collisions
 	randomCircle.x = pow(randomCircle.x, 4.0); // Get more samples closer to the normal
@@ -58,15 +65,16 @@ void main()
 	const U32 lod = 0u;
 	const F32 minStepf = 4.0;
 	const F32 noise = F32(random.x) * (1.0 / 65536.0);
-	raymarchGroundTruth(viewPos, randomHemisphere, uv, depth, u_unis.m_projMat, u_unis.m_maxSteps, u_depthRt,
-						u_trilinearClampSampler, F32(lod), u_unis.m_depthBufferSize, u_unis.m_firstStepPixels,
-						U32(mix(minStepf, F32(u_unis.m_firstStepPixels), noise)), hitPoint, hitAttenuation);
+	const U32 initialStep = U32(mix(minStepf, F32(u_unis.m_stepIncrement), noise));
+	raymarchGroundTruth(viewPos, randomHemisphere, uv, depth, u_clusteredShading.m_matrices.m_projectionJitter,
+						u_unis.m_maxSteps, u_depthRt, u_trilinearClampSampler, F32(lod), u_unis.m_depthBufferSize,
+						u_unis.m_stepIncrement, initialStep, hitPoint, hitAttenuation);
 
 	// Reject backfacing
 	ANKI_BRANCH if(hitAttenuation > 0.0)
 	{
-		const Vec3 hitNormal =
-			u_unis.m_normalMat * readNormalFromGBuffer(u_gbufferRt2, u_trilinearClampSampler, hitPoint.xy);
+		const Vec3 hitNormal = u_clusteredShading.m_matrices.m_viewRotation
+							   * readNormalFromGBuffer(u_gbufferRt2, u_trilinearClampSampler, hitPoint.xy);
 		F32 backFaceAttenuation;
 		rejectBackFaces(randomHemisphere, hitNormal, backFaceAttenuation);
 
@@ -78,7 +86,7 @@ void main()
 	ANKI_BRANCH if(hitAttenuation > 0.0)
 	{
 		// Reproject the UV because you are reading the previous frame
-		const Vec4 v4 = u_unis.m_prevViewProjMatMulInvViewProjMat * Vec4(UV_TO_NDC(hitPoint.xy), hitPoint.z, 1.0);
+		const Vec4 v4 = u_clusteredShading.m_matrices.m_reprojection * Vec4(UV_TO_NDC(hitPoint.xy), hitPoint.z, 1.0);
 		hitPoint.xy = NDC_TO_UV(v4.xy / v4.w);
 
 		// Read the light buffer
@@ -89,7 +97,7 @@ void main()
 #if 0
 		// Compute a new normal based on the new hit point
 		const F32 depth = textureLod(u_depthRt, u_trilinearClampSampler, hitPoint.xy, 0.0).r;
-		const Vec4 viewPos4 = u_unis.m_invProjMat * Vec4(UV_TO_NDC(hitPoint.xy), depth, 1.0);
+		const Vec4 viewPos4 = u_clusteredShading.m_matrices.m_invertedProjection * Vec4(UV_TO_NDC(hitPoint.xy), depth, 1.0);
 		const Vec3 hitViewPos = viewPos4.xyz / viewPos4.w;
 		const Vec3 newViewNormal = normalize(hitViewPos - viewPos);
 #else
@@ -103,10 +111,57 @@ void main()
 	}
 	else
 	{
-		outColor = Vec3(0.0, 0.0, 0.0);
+		// Fallback to probes
+
+		// Get the cluster
+		Cluster cluster = getClusterFragCoord(Vec3(fragCoord, depth));
+
+		// Get world position
+		const Vec4 worldPos4 = u_clusteredShading.m_matrices.m_invertedViewProjectionJitter * Vec4(ndc, depth, 1.0);
+		const Vec3 worldPos = worldPos4.xyz / worldPos4.w;
+
+		if(bitCount(cluster.m_giProbesMask) == 1)
+		{
+			// All subgroups point to the same probe and there is only one probe, do a fast path without blend weight
+
+			const GlobalIlluminationProbe probe = u_giProbes[findLSB2(cluster.m_giProbesMask)];
+
+			// Sample
+			outColor = sampleGlobalIllumination(worldPos, worldNormal, probe, u_globalIlluminationTextures,
+												u_trilinearClampSampler);
+		}
+		else
+		{
+			// Zero or more than one probes, do a slow path that blends them together
+
+			F32 totalBlendWeight = EPSILON;
+			outColor = Vec3(0.0);
+
+			// Loop probes
+			ANKI_LOOP while(cluster.m_giProbesMask != 0u)
+			{
+				const U32 idx = U32(findLSB2(cluster.m_giProbesMask));
+				cluster.m_giProbesMask &= ~(1u << idx);
+				const GlobalIlluminationProbe probe = u_giProbes[idx];
+
+				// Compute blend weight
+				const F32 blendWeight =
+					computeProbeBlendWeight(worldPos, probe.m_aabbMin, probe.m_aabbMax, probe.m_fadeDistance);
+				totalBlendWeight += blendWeight;
+
+				// Sample
+				const Vec3 c = sampleGlobalIllumination(worldPos, worldNormal, probe, u_globalIlluminationTextures,
+														u_trilinearClampSampler);
+				outColor += c * blendWeight;
+			}
+
+			// Normalize
+			outColor /= totalBlendWeight;
+		}
 	}
 
 	// Remove fireflies
+	if(false)
 	{
 		const F32 lum = computeLuminance(outColor) + 0.001;
 		const F32 averageLum = (subgroupAdd(lum) / F32(gl_SubgroupSize)) * 2.0;

+ 20 - 11
AnKi/Shaders/SsRaymarching.glsl

@@ -147,8 +147,15 @@ void raymarchGroundTruth(Vec3 rayOrigin, // Ray origin in view space
 						 Vec2 uv, // UV the ray starts
 						 F32 depthRef, // Depth the ray starts
 						 Mat4 projMat, // Projection matrix
-						 U32 maxIterations, texture2D depthTex, sampler depthSampler, F32 depthLod, UVec2 depthTexSize,
-						 U32 bigStep, U32 randInitialStep, out Vec3 hitPoint, out F32 attenuation)
+						 U32 maxSteps, // The max iterations of the base algorithm
+						 texture2D depthTex, // Depth tex
+						 sampler depthSampler, // Sampler for depthTex
+						 F32 depthLod, // LOD to pass to the textureLod
+						 UVec2 depthTexSize, // Size of the depthTex
+						 U32 initialStepIncrement, // Initial step increment
+						 U32 randInitialStep, // The initial step
+						 out Vec3 hitPoint, // Hit point in UV coordinates
+						 out F32 attenuation)
 {
 	attenuation = 0.0;
 
@@ -177,14 +184,14 @@ void raymarchGroundTruth(Vec3 rayOrigin, // Ray origin in view space
 	dir = normalize(dir);
 
 	// Compute step
-	I32 stepSkip = I32(bigStep);
-	I32 step = I32(randInitialStep);
+	I32 stepIncrement = I32(initialStepIncrement);
+	I32 crntStep = I32(randInitialStep);
 
-	// Iterate
+	// Search
 	Vec3 origin;
-	ANKI_LOOP while(maxIterations-- != 0u)
+	ANKI_LOOP while(maxSteps-- != 0u)
 	{
-		origin = start + dir * (F32(step) * stepSize);
+		origin = start + dir * (F32(crntStep) * stepSize);
 
 		// Check if it's out of the view
 		if(origin.x <= 0.0 || origin.y <= 0.0 || origin.x >= 1.0 || origin.y >= 1.0)
@@ -196,12 +203,14 @@ void raymarchGroundTruth(Vec3 rayOrigin, // Ray origin in view space
 		const Bool hit = origin.z - depth >= 0.0;
 		if(!hit)
 		{
-			step += stepSkip;
+			crntStep += stepIncrement;
 		}
-		else if(stepSkip > 1)
+		else if(stepIncrement > 1)
 		{
-			step = max(1, step - stepSkip + 1);
-			stepSkip = stepSkip / 2;
+			// There is a hit but the step increment is a bit high, need a more fine-grained search
+
+			crntStep = max(1, crntStep - stepIncrement + 1);
+			stepIncrement = stepIncrement / 2;
 		}
 		else
 		{