screen_space_reflection.glsl 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. #[compute]
  2. #version 450
  3. VERSION_DEFINES
  // One invocation per pixel, dispatched in 8x8 tiles.
  layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;

  // Set 0: inputs that are ray-marched / sampled.
  //   source_diffuse - colour fetched at the reflection hit point.
  //   source_depth   - single-channel depth, compared directly against the
  //                    ray's view-space z during the march.
  layout(set = 0, binding = 0, rgba16f) uniform restrict readonly image2D source_diffuse;
  layout(set = 0, binding = 1, r32f) uniform restrict readonly image2D source_depth;

  // Set 1: outputs.
  //   ssr_image         - rgb = reflected colour, a = blend weight.
  //   blur_radius_image - (MODE_ROUGH only) per-pixel blur radius, written as radius / 255.
  layout(set = 1, binding = 0, rgba16f) uniform restrict writeonly image2D ssr_image;
  #ifdef MODE_ROUGH
  layout(set = 1, binding = 1, r8) uniform restrict writeonly image2D blur_radius_image;
  #endif

  // Set 2: xyz = normal encoded as n * 0.5 + 0.5 (decoded in main), w = roughness.
  layout(set = 2, binding = 0, rgba8) uniform restrict readonly image2D source_normal_roughness;

  // Set 3: metallic source; main() fetches it at (pixel << 1), mip 0.
  // NOTE(review): the << 1 suggests this texture is twice the SSR resolution - confirm with the caller.
  layout(set = 3, binding = 0) uniform sampler2D source_metallic;
  // Per-dispatch parameters. Member order and types define the push-constant
  // layout and must stay in sync with the CPU side.
  layout(push_constant, binding = 2, std430) uniform Params {
      // xy = scale, zw = offset applied to pixel coordinates when rebuilding
      // a view-space position (see reconstructCSPosition).
      vec4 proj_info;

      ivec2 screen_size; // Size in pixels of the image being processed.
      float camera_z_near; // Used to clip the reflection ray against the near plane.
      float camera_z_far; // Maximum ray length / far rejection distance.

      int num_steps; // Maximum number of march iterations.
      float depth_tolerance; // How far behind the ray a depth sample may be and still count as a hit.
      float distance_fade; // Exponent applied to (1 - travelled fraction).
      float curve_fade_in; // Exponent applied to the travelled fraction; 0 disables the fade-in.

      bool orthogonal; // Selects the orthographic branch of position reconstruction.
      float filter_mipmap_levels; // Not referenced by the code in this shader.
      bool use_half_res; // Not referenced by the code in this shader.
      uint metallic_mask; // Packed unorm4x8 weights dotted with the source_metallic texel.

      mat4 projection; // View-space -> clip-space matrix used by view_to_screen.
  }
  params;
  // Projects a view-space position with params.projection.
  //   view_pos - position to project.
  //   w        - receives the clip-space w of the projected point.
  // Returns the projected xy remapped from [-1, 1] to [0, 1].
  vec2 view_to_screen(vec3 view_pos, out float w) {
      vec4 clip = params.projection * vec4(view_pos, 1.0);
      vec3 ndc = clip.xyz / clip.w; // perspective divide
      w = clip.w;
      return ndc.xy * 0.5 + 0.5;
  }
  #define M_PI 3.14159265359

  // Rebuilds a position from a pixel coordinate and a depth value.
  //   S - pixel coordinate (main() passes the pixel centre).
  //   z - depth value; returned unchanged as the z component.
  // The xy plane term is S * proj_info.xy + proj_info.zw. It is used as-is for
  // orthogonal projections and scaled by z otherwise.
  vec3 reconstructCSPosition(vec2 S, float z) {
      vec2 plane_xy = S.xy * params.proj_info.xy + params.proj_info.zw;
      if (!params.orthogonal) {
          plane_xy = plane_xy * z;
      }
      return vec3(plane_xy, z);
  }
  // Writes the "no reflection" result for one pixel: zero colour and blend
  // weight to ssr_image and, in MODE_ROUGH, a zero blur radius. Every rejection
  // path goes through here so both outputs are always written together.
  void store_no_reflection(ivec2 coord) {
  #ifdef MODE_ROUGH
      imageStore(blur_radius_image, coord, vec4(0.0));
  #endif
      imageStore(ssr_image, coord, vec4(0.0));
  }

  // Screen-space reflection pass, one invocation per pixel.
  //
  // Reflects the view vector about the stored normal, projects the reflected
  // ray to screen space and marches it one pixel at a time through
  // source_depth (interpolating z/w and 1/w for perspective correctness).
  // On a hit, writes the colour found there to ssr_image with an alpha built
  // from distance fade, screen-margin fade and the metallic mask; otherwise
  // writes zeros. In MODE_ROUGH a roughness-derived blur radius is also written.
  void main() {
      // Pixel being shaded.
      ivec2 ssC = ivec2(gl_GlobalInvocationID.xy);
      if (any(greaterThanEqual(ssC, params.screen_size))) {
          // Invocation falls outside the image: nothing to write.
          return;
      }

      vec2 pixel_size = 1.0 / vec2(params.screen_size);
      vec2 uv = vec2(ssC) * pixel_size;
      uv += pixel_size * 0.5;

      float base_depth = imageLoad(source_depth, ssC).r;

      // Position of the shaded point, reconstructed from the pixel centre and its depth.
      vec3 vertex = reconstructCSPosition(uv * vec2(params.screen_size), base_depth);

      vec4 normal_roughness = imageLoad(source_normal_roughness, ssC);
      vec3 normal = normal_roughness.xyz * 2.0 - 1.0;
      normal = normalize(normal);
      normal.y = -normal.y; // because this code reads flipped

      vec3 view_dir = normalize(vertex);
      vec3 ray_dir = normalize(reflect(view_dir, normal));

      if (dot(ray_dir, normal) < 0.001) {
          // Reflected ray does not leave the surface.
          store_no_reflection(ssC);
          return;
      }

      // Make the ray length and clip it against the near plane (don't want to trace beyond visible).
      float ray_len = (vertex.z + ray_dir.z * params.camera_z_far) > -params.camera_z_near ? (-params.camera_z_near - vertex.z) / ray_dir.z : params.camera_z_far;
      vec3 ray_end = vertex + ray_dir * ray_len;

      float w_begin;
      vec2 vp_line_begin = view_to_screen(vertex, w_begin);
      float w_end;
      vec2 vp_line_end = view_to_screen(ray_end, w_end);
      vec2 vp_line_dir = vp_line_end - vp_line_begin;

      // We need to interpolate w along the ray to generate perspective-correct reflections.
      w_begin = 1.0 / w_begin;
      w_end = 1.0 / w_end;

      float z_begin = vertex.z * w_begin;
      float z_end = ray_end.z * w_end;

      // Same line expressed in pixels.
      vec2 line_begin = vp_line_begin / pixel_size;
      vec2 line_dir = vp_line_dir / pixel_size;
      float z_dir = z_end - z_begin;
      float w_dir = w_end - w_begin;

      // Clip the line to the viewport edges.
      float scale_max_x = min(1.0, 0.99 * (1.0 - vp_line_begin.x) / max(1e-5, vp_line_dir.x));
      float scale_max_y = min(1.0, 0.99 * (1.0 - vp_line_begin.y) / max(1e-5, vp_line_dir.y));
      float scale_min_x = min(1.0, 0.99 * vp_line_begin.x / max(1e-5, -vp_line_dir.x));
      float scale_min_y = min(1.0, 0.99 * vp_line_begin.y / max(1e-5, -vp_line_dir.y));
      float line_clip = min(scale_max_x, scale_max_y) * min(scale_min_x, scale_min_y);
      line_dir *= line_clip;
      z_dir *= line_clip;
      w_dir *= line_clip;

      // Clip z and w advance to line advance.
      // NOTE(review): normalize() is undefined for a zero-length line_dir; no guard exists for that case.
      vec2 line_advance = normalize(line_dir); // down to pixel
      float step_size = length(line_advance) / length(line_dir);
      float z_advance = z_dir * step_size; // adapt z advance to line advance
      float w_advance = w_dir * step_size; // adapt w advance to line advance

      // Make the line advance faster if the direction is closer to pixel edges
      // (this avoids sampling the same pixel twice).
      float advance_angle_adj = 1.0 / max(abs(line_advance.x), abs(line_advance.y));
      line_advance *= advance_angle_adj; // dominant axis now moves exactly one pixel per step
      z_advance *= advance_angle_adj;
      w_advance *= advance_angle_adj;

      vec2 pos = line_begin;
      float z = z_begin;
      float w = w_begin;
      float z_from = z / w;
      float z_to = z_from;

      bool found = false;
      float steps_taken = 0.0;

      for (int i = 0; i < params.num_steps; i++) {
          pos += line_advance;
          z += z_advance;
          w += w_advance;

          float depth = imageLoad(source_depth, ivec2(pos - 0.5)).r;

          // Recover the ray's z at this step from the interpolated z/w and 1/w.
          z_from = z_to;
          z_to = z / w;

          if (depth > z_to) {
              // The ray went behind the depth buffer. Accept the hit only when it is
              // within the depth tolerance and not at/after the far clip.
              if (depth <= max(z_to, z_from) + params.depth_tolerance && -depth < params.camera_z_far) {
                  found = true;
              }
              break;
          }

          steps_taken += 1.0;
      }

      if (!found) {
          store_no_reflection(ssC);
          return;
      }

      vec2 margin = vec2((params.screen_size.x + params.screen_size.y) * 0.5 * 0.05); // make a uniform margin

      if (any(bvec4(lessThan(pos, -margin), greaterThan(pos, params.screen_size + margin)))) {
          // Hit lies outside screen + margin.
          store_no_reflection(ssC);
          return;
      }

      // Blend fading out towards the external margin.
      vec2 margin_grad = mix(pos - params.screen_size, -pos, lessThan(pos, vec2(0.0)));
      float margin_blend = 1.0 - smoothstep(0.0, margin.x, max(margin_grad.x, margin_grad.y));

      // Fraction of the step budget consumed before the hit.
      float grad = steps_taken / float(params.num_steps);
      float initial_fade = params.curve_fade_in == 0.0 ? 1.0 : pow(clamp(grad, 0.0, 1.0), params.curve_fade_in);
      float fade = pow(clamp(1.0 - grad, 0.0, 1.0), params.distance_fade) * initial_fade;

  #ifdef MODE_ROUGH
      // If roughness is enabled, do screen space cone tracing.
      float blur_radius = 0.0;
      float roughness = normal_roughness.w;

      if (roughness > 0.001) {
          float cone_angle = min(roughness, 0.999) * M_PI * 0.5;
          float cone_len = length(pos - line_begin);
          float op_len = 2.0 * tan(cone_angle) * cone_len; // opposite side of iso triangle

          // Fit to sphere inside cone (sphere ends at end of cone), something like this:
          //  ___
          //  \O/
          //   V
          //
          // as it avoids bleeding from beyond the reflection as much as possible. As a plus
          // it also makes the rough reflection more elongated.
          float a = op_len;
          float h = cone_len;
          float a2 = a * a;
          float fh2 = 4.0f * h * h;
          blur_radius = (a * (sqrt(a2 + fh2) - a)) / (4.0f * h);
      }

      imageStore(blur_radius_image, ssC, vec4(blur_radius / 255.0)); // stored in r8
  #endif

      vec4 final_color = vec4(imageLoad(source_diffuse, ivec2(pos - 0.5)).rgb, fade * margin_blend);

      // Change blend by metallic.
      vec4 metallic_mask = unpackUnorm4x8(params.metallic_mask);
      final_color.a *= dot(metallic_mask, texelFetch(source_metallic, ssC << 1, 0));

      imageStore(ssr_image, ssC, final_color);
  }