Browse Source

Optimize SSAO for mobile

Panagiotis Christopoulos Charitos 2 years ago
parent
commit
9395f9c64e

+ 2 - 2
AnKi/Shaders/Include/MiscRendererTypes.h

@@ -127,7 +127,7 @@ struct VolumetricLightingConstants
 // SSAO
 // SSAO
 struct SsaoConstants
 struct SsaoConstants
 {
 {
-	F32 m_radius; ///< In meters.
+	RF32 m_radius; ///< In meters.
 	U32 m_sampleCount;
 	U32 m_sampleCount;
 	Vec2 m_viewportSizef;
 	Vec2 m_viewportSizef;
 
 
@@ -139,7 +139,7 @@ struct SsaoConstants
 	F32 m_projectionMat23;
 	F32 m_projectionMat23;
 
 
 	Vec2 m_prevJitterUv;
 	Vec2 m_prevJitterUv;
-	F32 m_ssaoPower;
+	RF32 m_ssaoPower;
 	U32 m_frameCount;
 	U32 m_frameCount;
 
 
 	Mat3x4 m_viewMat;
 	Mat3x4 m_viewMat;

+ 48 - 48
AnKi/Shaders/Ssao.ankiprog

@@ -28,7 +28,7 @@
 
 
 [[vk::binding(5)]] Texture2D<RVec4> g_historyTex;
 [[vk::binding(5)]] Texture2D<RVec4> g_historyTex;
 [[vk::binding(6)]] Texture2D<Vec4> g_motionVectorsTex;
 [[vk::binding(6)]] Texture2D<Vec4> g_motionVectorsTex;
-[[vk::binding(7)]] Texture2D<Vec4> g_historyLengthTex;
+[[vk::binding(7)]] Texture2D<RVec4> g_historyLengthTex;
 
 
 #	if defined(ANKI_COMPUTE_SHADER)
 #	if defined(ANKI_COMPUTE_SHADER)
 [[vk::binding(8)]] RWTexture2D<RVec4> g_outUav;
 [[vk::binding(8)]] RWTexture2D<RVec4> g_outUav;
@@ -49,7 +49,7 @@ Vec4 project(Vec4 p)
 	return projectPerspective(p, g_consts.m_projectionMat00, g_consts.m_projectionMat11, g_consts.m_projectionMat22, g_consts.m_projectionMat23);
 	return projectPerspective(p, g_consts.m_projectionMat00, g_consts.m_projectionMat11, g_consts.m_projectionMat22, g_consts.m_projectionMat23);
 }
 }
 
 
-F32 computeFalloff(F32 len)
+RF32 computeFalloff(RF32 len)
 {
 {
 	// Distance falloff: sqrt(1 - min(1, len / m_radius)). Yields 1 at len == 0 and 0 once len reaches
 	// g_consts.m_radius (documented as meters in SsaoConstants); the min() keeps the sqrt argument non-negative.
 	// NOTE(review): no guard for m_radius == 0 is visible here - confirm the CPU side never uploads a zero radius.
 	return sqrt(1.0f - min(1.0f, len / g_consts.m_radius));
 	return sqrt(1.0f - min(1.0f, len / g_consts.m_radius));
 }
 }
@@ -69,82 +69,82 @@ RF32 main([[vk::location(0)]] Vec2 uv : TEXCOORD, Vec4 svPosition : SV_POSITION)
 
 
 	const Vec2 ndc = uvToNdc(uv);
 	const Vec2 ndc = uvToNdc(uv);
 	const Vec3 Pc = unproject(ndc);
 	const Vec3 Pc = unproject(ndc);
-	const Vec3 V = normalize(-Pc); // View vector
+	const RVec3 V = normalize(-Pc); // View vector
 
 
 	// Get noise
 	// Get noise
 #	if 0
 #	if 0
 	Vec2 noiseTexSize;
 	Vec2 noiseTexSize;
 	g_noiseTex.GetDimensions(noiseTexSize.x, noiseTexSize.y);
 	g_noiseTex.GetDimensions(noiseTexSize.x, noiseTexSize.y);
-	const Vec2 noiseUv = Vec2(g_consts.m_viewportSizef) / noiseTexSize * uv;
-	const Vec2 noise2 = animateBlueNoise(g_noiseTex.SampleLevel(g_trilinearRepeatSampler, noiseUv, 0.0).xyz, g_consts.m_frameCount).yx;
+	const RVec2 noiseUv = Vec2(g_consts.m_viewportSizef) / noiseTexSize * uv;
+	const RVec2 noise2 = animateBlueNoise(g_noiseTex.SampleLevel(g_trilinearRepeatSampler, noiseUv, 0.0).xyz, g_consts.m_frameCount).yx;
 #	else
 #	else
-	const Vec2 noise2 = spatioTemporalNoise(svDispatchThreadId, g_consts.m_frameCount);
+	const RVec2 noise2 = spatioTemporalNoise(svDispatchThreadId, g_consts.m_frameCount);
 #	endif
 #	endif
 
 
 	// Rand slice direction
 	// Rand slice direction
-	const F32 randAng = noise2.x * kPi;
+	const RF32 randAng = noise2.x * kPi;
 #	if 0
 #	if 0
-	const F32 aspect = g_consts.m_viewportSizef.x / g_consts.m_viewportSizef.y;
-	const Vec2 dir2d = normalize(Vec2(cos(randAng), sin(randAng)) * Vec2(1.0f, aspect));
+	const RF32 aspect = g_consts.m_viewportSizef.x / g_consts.m_viewportSizef.y;
+	const RVec2 dir2d = normalize(Vec2(cos(randAng), sin(randAng)) * Vec2(1.0f, aspect));
 #	else
 #	else
-	const Vec2 dir2d = Vec2(cos(randAng), sin(randAng));
+	const RVec2 dir2d = Vec2(cos(randAng), sin(randAng));
 #	endif
 #	endif
 
 
 	// Project the view normal to the slice
 	// Project the view normal to the slice
 	const Vec3 worldNormal = unpackNormalFromGBuffer(g_gbufferRt2.SampleLevel(g_linearAnyClampSampler, uv, 0.0));
 	const Vec3 worldNormal = unpackNormalFromGBuffer(g_gbufferRt2.SampleLevel(g_linearAnyClampSampler, uv, 0.0));
-	const Vec3 viewNormal = mul(g_consts.m_viewMat, Vec4(worldNormal, 0.0));
+	const RVec3 viewNormal = mul(g_consts.m_viewMat, Vec4(worldNormal, 0.0));
 
 
-	const Vec3 directionVec = Vec3(dir2d, 0.0f);
-	const Vec3 orthoDirectionVec = directionVec - (dot(directionVec, V) * V);
-	const Vec3 axisVec = normalize(cross(orthoDirectionVec, V));
-	const Vec3 projectedNormalVec = viewNormal - axisVec * dot(viewNormal, axisVec);
-	const F32 signNorm = (F32)sign(dot(orthoDirectionVec, projectedNormalVec));
-	const F32 projectedNormalVecLength = length(projectedNormalVec);
-	const F32 cosNorm = saturate(dot(projectedNormalVec, V) / projectedNormalVecLength);
-	const F32 n = -signNorm * fastAcos(cosNorm);
+	const RVec3 directionVec = RVec3(dir2d, 0.0f);
+	const RVec3 orthoDirectionVec = directionVec - (dot(directionVec, V) * V);
+	const RVec3 axisVec = normalize(cross(orthoDirectionVec, V));
+	const RVec3 projectedNormalVec = viewNormal - axisVec * dot(viewNormal, axisVec);
+	const RF32 signNorm = (F32)sign(dot(orthoDirectionVec, projectedNormalVec));
+	const RF32 projectedNormalVecLength = length(projectedNormalVec);
+	const RF32 cosNorm = saturate(dot(projectedNormalVec, V) / projectedNormalVecLength);
+	const RF32 n = -signNorm * fastAcos(cosNorm);
 
 
 	// Find the projected radius
 	// Find the projected radius
 	const Vec3 sphereLimit = Pc + Vec3(g_consts.m_radius, 0.0, 0.0);
 	const Vec3 sphereLimit = Pc + Vec3(g_consts.m_radius, 0.0, 0.0);
 	const Vec4 projSphereLimit = project(Vec4(sphereLimit, 1.0));
 	const Vec4 projSphereLimit = project(Vec4(sphereLimit, 1.0));
 	const Vec2 projSphereLimit2 = projSphereLimit.xy / projSphereLimit.w;
 	const Vec2 projSphereLimit2 = projSphereLimit.xy / projSphereLimit.w;
-	const F32 projRadius = length(projSphereLimit2 - ndc);
+	const RF32 projRadius = length(projSphereLimit2 - ndc);
 
 
 	// Compute the inner integral (Slide 54)
 	// Compute the inner integral (Slide 54)
 	const U32 stepCount = max(1u, g_consts.m_sampleCount / 2u);
 	const U32 stepCount = max(1u, g_consts.m_sampleCount / 2u);
 
 
-	const F32 lowHorizonCos1 = cos(n - kPi / 2.0f);
-	const F32 lowHorizonCos2 = cos(n + kPi / 2.0f);
+	const RF32 lowHorizonCos1 = cos(n - kPi / 2.0f);
+	const RF32 lowHorizonCos2 = cos(n + kPi / 2.0f);
 
 
-	F32 cosH1 = lowHorizonCos1;
-	F32 cosH2 = lowHorizonCos2;
+	RF32 cosH1 = lowHorizonCos1;
+	RF32 cosH2 = lowHorizonCos2;
 
 
 	for(U32 i = 0u; i < stepCount; ++i)
 	for(U32 i = 0u; i < stepCount; ++i)
 	{
 	{
-		const F32 stepBaseNoise = F32(i * stepCount) * 0.6180339887498948482;
-		const F32 stepNoise = frac(noise2.y + stepBaseNoise);
-		F32 s = (i + stepNoise) / F32(stepCount);
+		const RF32 stepBaseNoise = RF32(i * stepCount) * 0.6180339887498948482;
+		const RF32 stepNoise = frac(noise2.y + stepBaseNoise);
+		RF32 s = (i + stepNoise) / RF32(stepCount);
 		s *= s;
 		s *= s;
 		const Vec2 sampleOffset = dir2d * projRadius * s;
 		const Vec2 sampleOffset = dir2d * projRadius * s;
 
 
 		// h1
 		// h1
 		const Vec3 Ps = unproject(ndc + sampleOffset);
 		const Vec3 Ps = unproject(ndc + sampleOffset);
 		const Vec3 Ds = Ps - Pc;
 		const Vec3 Ds = Ps - Pc;
-		const F32 DsLen = length(Ds);
+		const RF32 DsLen = length(Ds);
 		cosH1 = max(cosH1, lerp(lowHorizonCos1, dot(V, Ds) / DsLen, computeFalloff(DsLen)));
 		cosH1 = max(cosH1, lerp(lowHorizonCos1, dot(V, Ds) / DsLen, computeFalloff(DsLen)));
 
 
 		// h2
 		// h2
 		const Vec3 Pt = unproject(ndc - sampleOffset);
 		const Vec3 Pt = unproject(ndc - sampleOffset);
 		const Vec3 Dt = Pt - Pc;
 		const Vec3 Dt = Pt - Pc;
-		const F32 DtLen = length(Dt);
+		const RF32 DtLen = length(Dt);
 		cosH2 = max(cosH2, lerp(lowHorizonCos2, dot(V, Dt) / DtLen, computeFalloff(DtLen)));
 		cosH2 = max(cosH2, lerp(lowHorizonCos2, dot(V, Dt) / DtLen, computeFalloff(DtLen)));
 	}
 	}
 
 
 	// Compute the h1 and h2
 	// Compute the h1 and h2
-	const F32 h1 = n + max(-fastAcos(cosH1) - n, -kPi / 2);
-	const F32 h2 = n + min(fastAcos(cosH2) - n, kPi / 2);
+	const RF32 h1 = n + max(-fastAcos(cosH1) - n, -kPi / 2);
+	const RF32 h2 = n + min(fastAcos(cosH2) - n, kPi / 2);
 
 
 	// Compute the final value (Slide 61)
 	// Compute the final value (Slide 61)
-	F32 Vd = -cos(2.0f * h1 - n) + cos(n) + 2.0f * h1 * sin(n);
+	RF32 Vd = -cos(2.0f * h1 - n) + cos(n) + 2.0f * h1 * sin(n);
 	Vd += -cos(2.0f * h2 - n) + cos(n) + 2.0f * h2 * sin(n);
 	Vd += -cos(2.0f * h2 - n) + cos(n) + 2.0f * h2 * sin(n);
 	Vd *= 0.25;
 	Vd *= 0.25;
 	Vd *= projectedNormalVecLength;
 	Vd *= projectedNormalVecLength;
@@ -158,22 +158,22 @@ RF32 main([[vk::location(0)]] Vec2 uv : TEXCOORD, Vec4 svPosition : SV_POSITION)
 
 
 		// History length creates black trails so it doesn't work correctly
 		// History length creates black trails so it doesn't work correctly
 #	if 0
 #	if 0
-		const Vec4 historyLengths = g_historyLengthTex.GatherRed(g_linearAnyClampSampler, uv + g_consts.m_prevJitterUv);
-		const F32 historyLength = max4(historyLengths);
+		const RVec4 historyLengths = g_historyLengthTex.GatherRed(g_linearAnyClampSampler, uv + g_consts.m_prevJitterUv);
+		const RF32 historyLength = max4(historyLengths);
 #	else
 #	else
-		const F32 historyLength = (any(historyUv < 0.0f) || any(historyUv > 1.0f)) ? 0.0f : 1.0f;
+		const RF32 historyLength = (any(historyUv < 0.0f) || any(historyUv > 1.0f)) ? 0.0f : 1.0f;
 #	endif
 #	endif
 
 
-		const F32 lowestBlendFactor = 0.1f;
-		const F32 maxHistoryLength = 16.0f;
-		const F32 stableFrames = 4.0f;
-		const F32 lerpVal = min(1.0f, (historyLength * maxHistoryLength - 1.0f) / stableFrames);
-		const F32 blendFactor = lerp(1.0f, lowestBlendFactor, lerpVal);
+		const RF32 lowestBlendFactor = 0.1f;
+		const RF32 maxHistoryLength = 16.0f;
+		const RF32 stableFrames = 4.0f;
+		const RF32 lerpVal = min(1.0f, (historyLength * maxHistoryLength - 1.0f) / stableFrames);
+		const RF32 blendFactor = lerp(1.0f, lowestBlendFactor, lerpVal);
 
 
 		// Blend with history
 		// Blend with history
 		if(blendFactor < 1.0)
 		if(blendFactor < 1.0)
 		{
 		{
-			const F32 history = g_historyTex.SampleLevel(g_linearAnyClampSampler, historyUv, 0.0f).r;
+			const RF32 history = g_historyTex.SampleLevel(g_linearAnyClampSampler, historyUv, 0.0f).r;
 			Vd = lerp(history, Vd, blendFactor);
 			Vd = lerp(history, Vd, blendFactor);
 		}
 		}
 	}
 	}
@@ -194,11 +194,11 @@ RF32 main([[vk::location(0)]] Vec2 uv : TEXCOORD, Vec4 svPosition : SV_POSITION)
 #	include <AnKi/Shaders/BilateralFilter.hlsl>
 #	include <AnKi/Shaders/BilateralFilter.hlsl>
 
 
 [[vk::binding(0)]] SamplerState g_linearAnyClampSampler;
 [[vk::binding(0)]] SamplerState g_linearAnyClampSampler;
-[[vk::binding(1)]] Texture2D<Vec4> g_inTex;
+[[vk::binding(1)]] Texture2D<RVec4> g_inTex;
 [[vk::binding(2)]] Texture2D<Vec4> g_depthTex;
 [[vk::binding(2)]] Texture2D<Vec4> g_depthTex;
 
 
 #	if defined(ANKI_COMPUTE_SHADER)
 #	if defined(ANKI_COMPUTE_SHADER)
-[[vk::binding(3)]] RWTexture2D<Vec4> g_outImg;
+[[vk::binding(3)]] RWTexture2D<RVec4> g_outImg;
 #	endif
 #	endif
 
 
 F32 readDepth(Vec2 uv)
 F32 readDepth(Vec2 uv)
@@ -206,10 +206,10 @@ F32 readDepth(Vec2 uv)
 	return g_depthTex.SampleLevel(g_linearAnyClampSampler, uv, 0.0).x;
 	return g_depthTex.SampleLevel(g_linearAnyClampSampler, uv, 0.0).x;
 }
 }
 
 
-void sampleTex(Vec2 uv, F32 refDepth, inout F32 col, inout F32 weight)
+void sampleTex(Vec2 uv, F32 refDepth, inout RF32 col, inout RF32 weight)
 {
 {
 	// Accumulates one filter tap: samples g_inTex.x at uv, weights it with the value returned by
 	// calculateBilateralWeightDepth(refDepth, readDepth(uv), 1.0f), then adds color * w into col and w into weight.
 	// The caller is expected to normalize col by weight afterwards - presumably in main; verify against the full file.
 	// Only the accumulators and the sampled value change type in this commit; refDepth and uv stay F32/Vec2.
-	const F32 color = g_inTex.SampleLevel(g_linearAnyClampSampler, uv, 0.0).x;
-	const F32 w = calculateBilateralWeightDepth(refDepth, readDepth(uv), 1.0f);
+	const RF32 color = g_inTex.SampleLevel(g_linearAnyClampSampler, uv, 0.0).x;
+	const RF32 w = calculateBilateralWeightDepth(refDepth, readDepth(uv), 1.0f);
 	col += color * w;
 	col += color * w;
 	weight += w;
 	weight += w;
 }
 }
@@ -232,9 +232,9 @@ F32 main([[vk::location(0)]] Vec2 uv : TEXCOORD) : SV_TARGET0
 	const Vec2 texelSize = 1.0 / Vec2(textureSize);
 	const Vec2 texelSize = 1.0 / Vec2(textureSize);
 
 
 	// Sample
 	// Sample
-	F32 color = g_inTex.SampleLevel(g_linearAnyClampSampler, uv, 0.0).r;
+	RF32 color = g_inTex.SampleLevel(g_linearAnyClampSampler, uv, 0.0).r;
 	const F32 refDepth = readDepth(uv);
 	const F32 refDepth = readDepth(uv);
-	F32 weight = 1.0;
+	RF32 weight = 1.0;
 
 
 #	if defined(ANKI_TECHNIQUE_SsaoDenoiseHorizontal)
 #	if defined(ANKI_TECHNIQUE_SsaoDenoiseHorizontal)
 #		define X_OR_Y x
 #		define X_OR_Y x

+ 1 - 1
AnKi/Shaders/VisualizeRenderTarget.ankiprog

@@ -15,6 +15,6 @@
 
 
 Vec3 main(Vec2 uv : TEXCOORD) : SV_TARGET0
 Vec3 main(Vec2 uv : TEXCOORD) : SV_TARGET0
 {
 {
 	// Point-samples mip 0 of g_inTex at uv and writes it to the color target.
 	// NOTE(review): the new .rrr swizzle replicates the red channel into all three outputs, so the G and B
 	// channels of g_inTex are no longer displayed. That suits a single-channel target such as the SSAO result,
 	// but confirm it is intended for every render target this visualizer is used with.
-	return g_inTex.SampleLevel(g_nearestAnyClampSampler, uv, 0.0).rgb;
+	return g_inTex.SampleLevel(g_nearestAnyClampSampler, uv, 0.0).rrr;
 }
 }
 #pragma anki technique_end frag
 #pragma anki technique_end frag