Przeglądaj źródła

SSR: Compute the ray intersection with the camera frustum to get an end point, and use a for loop instead of a while loop to avoid a crash on some AMD GPUs

TothBenoit 11 miesięcy temu
rodzic
commit
3144d7333a
1 zmienionych plików z 60 dodań i 38 usunięć
  1. 60 38
      hrt/prefab/rfx/SSR.hx

+ 60 - 38
hrt/prefab/rfx/SSR.hx

@@ -11,6 +11,7 @@ class SSRShader extends h3d.shader.ScreenShader {
 		@param var normalMap : Sampler2D;
 
 		@param var cameraView : Mat4;
+		@param var cameraInverseView : Mat4;
 		@param var cameraProj : Mat4;
 		@param var cameraInverseProj : Mat4;
 		@param var cameraPos: Vec3;
@@ -22,6 +23,8 @@ class SSRShader extends h3d.shader.ScreenShader {
 		@param var minCosAngle : Float;
 		@param var rayMarchingResolution : Float;
 
+		@param var frustum : Buffer<Vec4, 6>;
+
 		@const var batchSample : Bool;
 		@const var CHECK_ANGLE : Bool;
 
@@ -38,50 +41,69 @@ class SSRShader extends h3d.shader.ScreenShader {
 			return vpos / vpos.w;
 		};
 
-		function fromWposToVpos(wpos : Vec3 ) : Vec3 {
-			var vpos = vec4(wpos, 1.0) * cameraView;
-			return vpos.xyz / vpos.w;
+		function intersectViewRayWithFrustum( start : Vec3, dir : Vec3 ) : Vec3 {
+			var wStart = vec4(start, 1.0) * cameraInverseView;
+			wStart /= wStart.w;
+			var wDir = normalize(dir * cameraInverseView.mat3());
+
+			var minT = 1000000.0;
+			for ( i in 0...6 ) {
+				var plane = frustum[i];
+				var num = plane.w - ( plane.x * wStart.x + plane.y * wStart.y + plane.z * wStart.z );
+				var denom = plane.x * wDir.x + plane.y * wDir.y + plane.z * wDir.z;
+				var t = num / denom;
+				if ( denom != 0.0 && t > 0.0 )
+					minT = min( minT, t );
+			}
+
+			var wEnd = wStart.xyz + wDir * minT;
+			var vEnd = vec4(wEnd, 1.0) * cameraView;
+			return vEnd.xyz / vEnd.w;
 		}
 
 		function fragment() {
 			var normal = normalMap.get(calculatedUV).rgb;
-
 			if (normal.dot(normal) <= 0)
 				discard;
 
-			 var roughnessFactor = 1 - smoothstep(0.0, maxRoughness, roughnessMap.get(calculatedUV).g);
+			var roughnessFactor = 1 - smoothstep(0.0, maxRoughness, roughnessMap.get(calculatedUV).g);
 			if (roughnessFactor <= 0)
 				discard;
 
 			var positionFrom = getViewPos(calculatedUV);
 			var camDir = normalize(positionFrom.xyz);
-			var viewNormal = normalize(fromWposToVpos(normal + cameraPos));
+			var viewNormal = normalize( normal * cameraView.mat3() );
 			var reflectedRay = reflectedRay(camDir, viewNormal);
-			reflectedRay /= length(reflectedRay.xy);
+			reflectedRay = normalize(reflectedRay);
 
-			var startFrag = positionFrom * cameraProj;
-			startFrag.xyz /= startFrag.w;
-			startFrag.xy = screenToUv(startFrag.xy);
-			startFrag.xy *= texSize;
+			var positionTo = intersectViewRayWithFrustum(positionFrom.xyz, reflectedRay);
 
-			var fragDir = vec4(positionFrom.xyz + reflectedRay, 1.0) * cameraProj;
-			fragDir.xyz /= fragDir.w;
-			fragDir.xy = screenToUv(fragDir.xy);
-			fragDir.xy *= texSize;
-			fragDir.xy = normalize(fragDir.xy - startFrag.xy);
+			var startFrag = calculatedUV * texSize;
+			var roundStartFrag = roundEven(startFrag);
+
+			var endFrag = vec4(positionTo, 1.0) * cameraProj;
+			endFrag.xyz /= endFrag.w;
+			endFrag.xy = screenToUv(endFrag.xy);
+			endFrag.xy *= texSize;
+			var roundEndFrag = roundEven(endFrag);
+
+			if ( roundStartFrag.x == roundEndFrag.x && roundStartFrag.y == roundEndFrag.y )
+				discard;
 
 			var hit = 0;
-			var increment = fragDir.xy / saturate(rayMarchingResolution);
+			var ray = endFrag.xy - startFrag.xy;
+			var rayLength = length(ray);
+			var stepCount = ceil(rayLength * rayMarchingResolution);
+			var increment = ray / stepCount;
 			var frag = startFrag.xy + increment;
 			var uv = frag / texSize;
 
 			if (!batchSample) {
-				do {
-					var positionTo = getViewPos(uv);
-					var viewStepLength = distance(positionTo.xy, positionFrom.xy);
-					var viewDistance = positionFrom.z + reflectedRay.z * viewStepLength;
-					var depth = viewDistance - positionTo.z;
-
+				var iStepCount = int( stepCount );
+				for ( curStep in 0...iStepCount ) {
+					var curPos = getViewPos(uv);
+					var viewDistance = (positionFrom.z * positionTo.z) / mix(positionTo.z, positionFrom.z, float( curStep + 1 ) / stepCount );
+					var depth = viewDistance - curPos.z;
 					if ( depth >= 0.0 && depth < thickness && screenDepth < 1 ) {
 						hit = 1;
 						break;
@@ -89,24 +111,23 @@ class SSRShader extends h3d.shader.ScreenShader {
 
 					frag += increment;
 					uv = frag / texSize;
-
-				} while (uv.x >= 0.0 && uv.x < 1.0 && uv.y >= 0.0 && uv.y < 1.0);
-			}
-			else {
-				do {
+				}
+			} else {
+				var iStepCount = int( ceil( stepCount / 4 ) );
+				for ( curStep in 0...iStepCount ) {
 					var results : Array<Bool, 4> = [false, false, false, false];
+
 					@unroll
 					for ( i in 0...4 ) {
-						var positionTo = getViewPos(uv);
-						var viewStepLength = distance(positionTo.xy, positionFrom.xy);
-						var viewDistance = positionFrom.z + reflectedRay.z * viewStepLength;
-						var depth = viewDistance - positionTo.z;
+						var curPos = getViewPos(uv);
+						var viewDistance = (positionFrom.z * positionTo.z) / mix(positionTo.z, positionFrom.z, float( curStep * 4 + i + 1 ) / stepCount );
+						var depth = viewDistance - curPos.z;
 						results[i] = depth >= 0.0 && depth < thickness && screenDepth < 1;
 						frag += increment;
 						uv = frag / texSize;
 					}
 
-					  if (results[0] || results[1] || results[2] || results[3]) {
+					if (results[0] || results[1] || results[2] || results[3]) {
 						hit = 1;
 						for ( j in 0...4 ) {
 							if (results[j]) {
@@ -116,11 +137,10 @@ class SSRShader extends h3d.shader.ScreenShader {
 						}
 						break;
 					}
-				} while (uv.x >= 0.0 && uv.x < 1.0 && uv.y >= 0.0 && uv.y < 1.0);
+				}
 			}
 
-
-			if (hit != 1)
+			if ( hit != 1 )
 				discard;
 
 			if ( CHECK_ANGLE ) {
@@ -179,18 +199,20 @@ class SSR extends RendererFX {
 			if ( minAngle == 0 )
 				ssrShader.CHECK_ANGLE = false;
 			ssrShader.minCosAngle = Math.cos(hxd.Math.degToRad(minAngle));
-			ssrShader.rayMarchingResolution = rayMarchingResolution;
 			var resRescale = 1.0;
 			if ( !support4K )
 				resRescale = hxd.Math.max(1.0, hxd.Math.max(ssrShader.texSize.x / 2560, ssrShader.texSize.y / 1440));
-			ssrShader.rayMarchingResolution /= resRescale;
+			ssrShader.rayMarchingResolution = hxd.Math.clamp(rayMarchingResolution / resRescale);
 			ssrShader.batchSample = batchSample;
 
 			ssrShader.cameraView = r.ctx.camera.mcam;
+			ssrShader.cameraInverseView = r.ctx.camera.getInverseView();
 			ssrShader.cameraProj = r.ctx.camera.mproj;
 			ssrShader.cameraInverseProj = r.ctx.camera.getInverseProj();
 			ssrShader.cameraPos = r.ctx.camera.pos;
 
+			ssrShader.frustum = r.ctx.getCameraFrustumBuffer();
+
 			ssr = r.allocTarget("ssr", false, textureSize / resRescale, hdrMap.format);
 			ssr.clear(0, 0);
 			r.ctx.engine.pushTarget(ssr);