浏览代码

Add a more wide bilateral filter in SSGI

Panagiotis Christopoulos Charitos 5 年之前
父节点
当前提交
efe57af0fc
共有 3 个文件被更改,包括 132 次插入14 次删除
  1. 79 0
      shaders/BilateralFilter.glsl
  2. 41 12
      shaders/SsgiDenoise.ankiprog
  3. 12 2
      src/anki/renderer/Ssgi.cpp

+ 79 - 0
shaders/BilateralFilter.glsl

@@ -0,0 +1,79 @@
+// Copyright (C) 2009-2020, Panagiotis Christopoulos Charitos and contributors.
+// All rights reserved.
+// Code licensed under the BSD License.
+// http://www.anki3d.org/LICENSE
+
+#pragma once
+
+#include <shaders/Common.glsl>
+
+struct BilateralSample
+{
+	F32 m_depth;
+	Vec3 m_position;
+	Vec3 m_normal;
+	F32 m_roughness;
+};
+
+struct BilateralConfig
+{
+	F32 m_depthWeight;
+	F32 m_normalWeight;
+	F32 m_planeWeight;
+	F32 m_roughnessWeight;
+};
+
+// https://cs.dartmouth.edu/~wjarosz/publications/mara17towards.html
+F32 calculateBilateralWeight(BilateralSample center, BilateralSample tap, BilateralConfig config)
+{
+	F32 depthWeight = 1.0;
+	F32 normalWeight = 1.0;
+	F32 planeWeight = 1.0;
+	F32 glossyWeight = 1.0;
+
+	if(config.m_depthWeight > 0.0)
+	{
+#if 0
+		depthWeight = max(0.0, 1.0 - abs(tap.m_depth - center.m_depth) * config.m_depthWeight);
+#else
+		const F32 diff = abs(tap.m_depth - center.m_depth);
+		depthWeight = sqrt(1.0 / (EPSILON + diff)) * config.m_depthWeight;
+#endif
+	}
+
+	if(config.m_normalWeight > 0.0)
+	{
+		F32 normalCloseness = dot(tap.m_normal, center.m_normal);
+		normalCloseness = normalCloseness * normalCloseness;
+		normalCloseness = normalCloseness * normalCloseness;
+
+		const F32 normalError = (1.0 - normalCloseness);
+		normalWeight = max((1.0 - normalError * config.m_normalWeight), 0.0);
+	}
+
+	if(config.m_planeWeight > 0.0)
+	{
+		const F32 lowDistanceThreshold2 = 0.001;
+
+		// Change in position in camera space
+		Vec3 dq = center.m_position - tap.m_position;
+
+		// How far away is this point from the original sample in camera space? (Max value is unbounded)
+		const F32 distance2 = dot(dq, dq);
+
+		// How far off the expected plane (on the perpendicular) is this point? Max value is unbounded.
+		const F32 planeError = max(abs(dot(dq, tap.m_normal)), abs(dot(dq, center.m_normal)));
+
+		planeWeight = (distance2 < lowDistanceThreshold2)
+						  ? 1.0
+						  : pow(max(0.0, 1.0 - 2.0 * config.m_planeWeight * planeError / sqrt(distance2)), 2.0);
+	}
+
+	if(config.m_roughnessWeight > 0.0)
+	{
+		const F32 gDiff = abs(tap.m_roughness - center.m_roughness) * 10.0;
+		glossyWeight = max(0.0, 1.0 - (gDiff * config.m_roughnessWeight));
+	}
+
+	return depthWeight * normalWeight * planeWeight * glossyWeight;
+}

+ 41 - 12
shaders/SsgiDenoise.ankiprog

@@ -11,7 +11,8 @@ ANKI_SPECIALIZATION_CONSTANT_UVEC2(IN_TEXTURE_SIZE, 0, UVec2(1));
 
 #pragma anki start comp
 
-#include <shaders/Common.glsl>
+#include <shaders/BilateralFilter.glsl>
+#include <shaders/Pack.glsl>
 
 #if SAMPLE_COUNT < 3
 #	error See file
@@ -23,13 +24,19 @@ layout(local_size_x = WORKGROUP_SIZE.x, local_size_y = WORKGROUP_SIZE.y, local_s
 layout(set = 0, binding = 0) uniform sampler u_linearAnyClampSampler;
 layout(set = 0, binding = 1) uniform texture2D u_inTex;
 layout(set = 0, binding = 2) uniform texture2D u_depthTex;
-layout(set = 0, binding = 3) writeonly uniform image2D u_outImg;
+layout(set = 0, binding = 3) uniform texture2D u_gbuffer2Tex;
+layout(set = 0, binding = 4) writeonly uniform image2D u_outImg;
 
-F32 computeDepthWeight(F32 refDepth, F32 depth)
+layout(std140, push_constant, row_major) uniform b_pc
 {
-	const F32 diff = abs(refDepth - depth);
-	const F32 weight = sqrt(1.0 / (EPSILON + diff));
-	return weight;
+	Mat4 u_invViewProjMat;
+};
+
+Vec3 unproject(Vec2 ndc, F32 depth)
+{
+	const Vec4 worldPos4 = u_invViewProjMat * Vec4(ndc, depth, 1.0);
+	const Vec3 worldPos = worldPos4.xyz / worldPos4.w;
+	return worldPos;
 }
 
 F32 readDepth(Vec2 uv)
@@ -37,10 +44,28 @@ F32 readDepth(Vec2 uv)
 	return textureLod(u_depthTex, u_linearAnyClampSampler, uv, 0.0).r;
 }
 
-void sampleTex(Vec2 inUv, Vec2 depthUv, F32 refDepth, inout Vec3 col, inout F32 weight)
+Vec3 readNormal(Vec2 uv)
+{
+	return readNormalFromGBuffer(u_gbuffer2Tex, u_linearAnyClampSampler, uv);
+}
+
+void sampleTex(Vec2 colorUv, Vec2 fullUv, BilateralSample ref, inout Vec3 col, inout F32 weight)
 {
-	const Vec3 color = textureLod(u_inTex, u_linearAnyClampSampler, inUv, 0.0).rgb;
-	const F32 w = computeDepthWeight(refDepth, readDepth(depthUv));
+	const Vec3 color = textureLod(u_inTex, u_linearAnyClampSampler, colorUv, 0.0).rgb;
+
+	BilateralSample crnt;
+	crnt.m_depth = readDepth(fullUv);
+	crnt.m_position = unproject(UV_TO_NDC(fullUv), crnt.m_depth);
+	crnt.m_normal = readNormal(fullUv);
+
+	BilateralConfig config;
+	const Vec3 weights = normalize(Vec3(0.0, 1.0, 1.0));
+	config.m_depthWeight = weights.x;
+	config.m_normalWeight = weights.y;
+	config.m_planeWeight = weights.z;
+	config.m_roughnessWeight = 0.0;
+
+	const F32 w = calculateBilateralWeight(crnt, ref, config);
 	col += color * w;
 	weight += w;
 }
@@ -71,9 +96,13 @@ void main()
 
 	// Reference
 	Vec3 color = textureLod(u_inTex, u_linearAnyClampSampler, inUv, 0.0).rgb;
-	const F32 refDepth = readDepth(depthUv);
 	F32 weight = 1.0;
 
+	BilateralSample ref;
+	ref.m_depth = readDepth(depthUv);
+	ref.m_position = unproject(UV_TO_NDC(depthUv), ref.m_depth);
+	ref.m_normal = readNormal(depthUv);
+
 #if ORIENTATION == 1
 #	define X_OR_Y x
 #else
@@ -87,8 +116,8 @@ void main()
 
 	ANKI_UNROLL for(U32 i = 0u; i < (SAMPLE_COUNT - 1u) / 2u; ++i)
 	{
-		sampleTex(inUv + inUvOffset, depthUv + depthUvOffset, refDepth, color, weight);
-		sampleTex(inUv - inUvOffset, depthUv - depthUvOffset, refDepth, color, weight);
+		sampleTex(inUv + inUvOffset, depthUv + depthUvOffset, ref, color, weight);
+		sampleTex(inUv - inUvOffset, depthUv - depthUvOffset, ref, color, weight);
 
 		inUvOffset.X_OR_Y += IN_TEXEL_SIZE.X_OR_Y;
 		depthUvOffset.X_OR_Y += 2.0 * DEPTH_TEXEL_SIZE.X_OR_Y;

+ 12 - 2
src/anki/renderer/Ssgi.cpp

@@ -171,6 +171,7 @@ void Ssgi::populateRenderGraph(RenderingContext& ctx)
 		rpass.newDependency({m_runCtx.m_intermediateRts[READ], TextureUsageBit::SAMPLED_COMPUTE});
 		rpass.newDependency({m_runCtx.m_intermediateRts[WRITE], TextureUsageBit::IMAGE_COMPUTE_WRITE});
 		rpass.newDependency({m_r->getGBuffer().getDepthRt(), TextureUsageBit::SAMPLED_COMPUTE});
+		rpass.newDependency({m_r->getGBuffer().getColorRt(2), TextureUsageBit::SAMPLED_COMPUTE});
 
 		rpass.setWork(
 			[](RenderPassWorkContext& rgraphCtx) { static_cast<Ssgi*>(rgraphCtx.m_userData)->runHBlur(rgraphCtx); },
@@ -185,6 +186,7 @@ void Ssgi::populateRenderGraph(RenderingContext& ctx)
 		rpass.newDependency({m_runCtx.m_intermediateRts[WRITE], TextureUsageBit::SAMPLED_COMPUTE});
 		rpass.newDependency({m_runCtx.m_finalRt, TextureUsageBit::IMAGE_COMPUTE_WRITE});
 		rpass.newDependency({m_r->getGBuffer().getDepthRt(), TextureUsageBit::SAMPLED_COMPUTE});
+		rpass.newDependency({m_r->getGBuffer().getColorRt(2), TextureUsageBit::SAMPLED_COMPUTE});
 
 		rpass.setWork(
 			[](RenderPassWorkContext& rgraphCtx) {
@@ -239,8 +241,12 @@ void Ssgi::runVBlur(RenderPassWorkContext& rgraphCtx)
 	cmdb->bindSampler(0, 0, m_r->getSamplers().m_trilinearClamp);
 	rgraphCtx.bindColorTexture(0, 1, m_runCtx.m_intermediateRts[WRITE]);
 	rgraphCtx.bindTexture(0, 2, m_r->getGBuffer().getDepthRt(), TextureSubresourceInfo(DepthStencilAspectBit::DEPTH));
+	rgraphCtx.bindColorTexture(0, 3, m_r->getGBuffer().getColorRt(2));
 
-	rgraphCtx.bindImage(0, 3, m_runCtx.m_intermediateRts[READ], TextureSubresourceInfo());
+	rgraphCtx.bindImage(0, 4, m_runCtx.m_intermediateRts[READ], TextureSubresourceInfo());
+
+	const Mat4 mat = m_runCtx.m_ctx->m_matrices.m_viewProjectionJitter.getInverse();
+	cmdb->setPushConstants(&mat, sizeof(mat));
 
 	dispatchPPCompute(cmdb, 8, 8, m_r->getWidth() / 2, m_r->getHeight() / 2);
 }
@@ -253,8 +259,12 @@ void Ssgi::runHBlur(RenderPassWorkContext& rgraphCtx)
 	cmdb->bindSampler(0, 0, m_r->getSamplers().m_trilinearClamp);
 	rgraphCtx.bindColorTexture(0, 1, m_runCtx.m_intermediateRts[READ]);
 	rgraphCtx.bindTexture(0, 2, m_r->getGBuffer().getDepthRt(), TextureSubresourceInfo(DepthStencilAspectBit::DEPTH));
+	rgraphCtx.bindColorTexture(0, 3, m_r->getGBuffer().getColorRt(2));
+
+	rgraphCtx.bindImage(0, 4, m_runCtx.m_intermediateRts[WRITE], TextureSubresourceInfo());
 
-	rgraphCtx.bindImage(0, 3, m_runCtx.m_intermediateRts[WRITE], TextureSubresourceInfo());
+	const Mat4 mat = m_runCtx.m_ctx->m_matrices.m_viewProjectionJitter.getInverse();
+	cmdb->setPushConstants(&mat, sizeof(mat));
 
 	dispatchPPCompute(cmdb, 8, 8, m_r->getWidth() / 2, m_r->getHeight() / 2);
 }