taa_resolve.glsl

///////////////////////////////////////////////////////////////////////////////////
// Copyright(c) 2016-2022 Panos Karabelas
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
// copies of the Software, and to permit persons to whom the Software is furnished
// to do so, subject to the following conditions :
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
///////////////////////////////////////////////////////////////////////////////////
// File changes (yyyy-mm-dd)
// 2022-05-06: Panos Karabelas: first commit
// 2020-12-05: Joan Fons: convert to Vulkan and Godot
///////////////////////////////////////////////////////////////////////////////////

#[compute]

#version 450

#VERSION_DEFINES

// Based on Spartan Engine's TAA implementation (without TAA upscale).
// <https://github.com/PanosK92/SpartanEngine/blob/a8338d0609b85dc32f3732a5c27fb4463816a3b9/Data/shaders/temporal_antialiasing.hlsl>

#ifndef MOLTENVK_USED
#define USE_SUBGROUPS
#endif // MOLTENVK_USED

#define GROUP_SIZE 8
#define FLT_MIN 0.00000001
#define FLT_MAX 32767.0
#define RPC_9 0.11111111111
#define RPC_16 0.0625
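
// RPC_9 and RPC_16 are the reciprocals 1/9 and 1/16, presumably precomputed so the
// 3x3 neighbourhood average and the 1/16 history blend weight can be applied as
// multiplications instead of divisions.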

#ifdef USE_SUBGROUPS
layout(local_size_x = GROUP_SIZE, local_size_y = GROUP_SIZE, local_size_z = 1) in;
#endif

layout(rgba16f, set = 0, binding = 0) uniform restrict readonly image2D color_buffer;
layout(set = 0, binding = 1) uniform sampler2D depth_buffer;
layout(rg16f, set = 0, binding = 2) uniform restrict readonly image2D velocity_buffer;
layout(rg16f, set = 0, binding = 3) uniform restrict readonly image2D last_velocity_buffer;
layout(set = 0, binding = 4) uniform sampler2D history_buffer;
layout(rgba16f, set = 0, binding = 5) uniform restrict writeonly image2D output_buffer;

layout(push_constant, std430) uniform Params {
	vec2 resolution;
	float disocclusion_threshold; // 0.1 / max(params.resolution.x, params.resolution.y)
	float disocclusion_scale;
}
params;

const ivec2 kOffsets3x3[9] = {
	ivec2(-1, -1),
	ivec2(0, -1),
	ivec2(1, -1),
	ivec2(-1, 0),
	ivec2(0, 0),
	ivec2(1, 0),
	ivec2(-1, 1),
	ivec2(0, 1),
	ivec2(1, 1),
};
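
// The nine offsets above enumerate the 3x3 neighbourhood around a texel in row-major
// order (top row, middle row, bottom row), with kOffsets3x3[4] being the centre texel.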

/*------------------------------------------------------------------------------
THREAD GROUP SHARED MEMORY (LDS)
------------------------------------------------------------------------------*/

const int kBorderSize = 1;
const int kGroupSize = GROUP_SIZE;
const int kTileDimension = kGroupSize + kBorderSize * 2;
const int kTileDimension2 = kTileDimension * kTileDimension;

vec3 reinhard(vec3 hdr) {
	return hdr / (hdr + 1.0);
}

vec3 reinhard_inverse(vec3 sdr) {
	return sdr / (1.0 - sdr);
}
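
// Reinhard maps HDR radiance [0, inf) into [0, 1); the history/input blend near the end of
// temporal_antialiasing() is done in this compressed range so very bright pixels (fireflies)
// cannot dominate the result, and reinhard_inverse undoes the mapping afterwards.
// For example, reinhard(vec3(3.0)) = vec3(0.75) and reinhard_inverse(vec3(0.75)) = vec3(3.0).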

float get_depth(ivec2 thread_id) {
	return texelFetch(depth_buffer, thread_id, 0).r;
}

#ifdef USE_SUBGROUPS
shared vec3 tile_color[kTileDimension][kTileDimension];
shared float tile_depth[kTileDimension][kTileDimension];

vec3 load_color(uvec2 group_thread_id) {
	group_thread_id += kBorderSize;
	return tile_color[group_thread_id.x][group_thread_id.y];
}

void store_color(uvec2 group_thread_id, vec3 color) {
	tile_color[group_thread_id.x][group_thread_id.y] = color;
}

float load_depth(uvec2 group_thread_id) {
	group_thread_id += kBorderSize;
	return tile_depth[group_thread_id.x][group_thread_id.y];
}

void store_depth(uvec2 group_thread_id, float depth) {
	tile_depth[group_thread_id.x][group_thread_id.y] = depth;
}

void store_color_depth(uvec2 group_thread_id, ivec2 thread_id) {
	// out of bounds clamp
	thread_id = clamp(thread_id, ivec2(0, 0), ivec2(params.resolution) - ivec2(1, 1));

	store_color(group_thread_id, imageLoad(color_buffer, thread_id).rgb);
	store_depth(group_thread_id, get_depth(thread_id));
}

void populate_group_shared_memory(uvec2 group_id, uint group_index) {
	// Populate group shared memory.
	ivec2 group_top_left = ivec2(group_id) * kGroupSize - kBorderSize;
	if (group_index < (kTileDimension2 >> 2)) {
		ivec2 group_thread_id_1 = ivec2(group_index % kTileDimension, group_index / kTileDimension);
		ivec2 group_thread_id_2 = ivec2((group_index + (kTileDimension2 >> 2)) % kTileDimension, (group_index + (kTileDimension2 >> 2)) / kTileDimension);
		ivec2 group_thread_id_3 = ivec2((group_index + (kTileDimension2 >> 1)) % kTileDimension, (group_index + (kTileDimension2 >> 1)) / kTileDimension);
		ivec2 group_thread_id_4 = ivec2((group_index + kTileDimension2 * 3 / 4) % kTileDimension, (group_index + kTileDimension2 * 3 / 4) / kTileDimension);

		store_color_depth(group_thread_id_1, group_top_left + group_thread_id_1);
		store_color_depth(group_thread_id_2, group_top_left + group_thread_id_2);
		store_color_depth(group_thread_id_3, group_top_left + group_thread_id_3);
		store_color_depth(group_thread_id_4, group_top_left + group_thread_id_4);
	}

	// Wait for all threads in the group to finish loading and storing their data.
	groupMemoryBarrier();
	barrier();
}
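
// How the cooperative load above works: an 8x8 group needs a 10x10 tile
// (kTileDimension = kGroupSize + 2 * kBorderSize) so the 3x3 neighbourhood of every pixel,
// including those on the group border, is available in shared memory. That is
// kTileDimension2 = 100 texels for 64 threads, so the first 25 threads (kTileDimension2 >> 2)
// each fetch 4 texels whose linear indices are offset by 0, 25, 50 and 75, covering
// indices 0..99 exactly once before the barrier.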

#else
vec3 load_color(uvec2 screen_pos) {
	return imageLoad(color_buffer, ivec2(screen_pos)).rgb;
}

float load_depth(uvec2 screen_pos) {
	return get_depth(ivec2(screen_pos));
}
#endif

/*------------------------------------------------------------------------------
VELOCITY
------------------------------------------------------------------------------*/

void depth_test_min(uvec2 pos, inout float min_depth, inout uvec2 min_pos) {
	float depth = load_depth(pos);

	if (depth < min_depth) {
		min_depth = depth;
		min_pos = pos;
	}
}

// Returns velocity with closest depth (3x3 neighborhood)
void get_closest_pixel_velocity_3x3(in uvec2 group_pos, uvec2 group_top_left, out vec2 velocity) {
	float min_depth = 1.0;
	uvec2 min_pos = group_pos;

	depth_test_min(group_pos + kOffsets3x3[0], min_depth, min_pos);
	depth_test_min(group_pos + kOffsets3x3[1], min_depth, min_pos);
	depth_test_min(group_pos + kOffsets3x3[2], min_depth, min_pos);
	depth_test_min(group_pos + kOffsets3x3[3], min_depth, min_pos);
	depth_test_min(group_pos + kOffsets3x3[4], min_depth, min_pos);
	depth_test_min(group_pos + kOffsets3x3[5], min_depth, min_pos);
	depth_test_min(group_pos + kOffsets3x3[6], min_depth, min_pos);
	depth_test_min(group_pos + kOffsets3x3[7], min_depth, min_pos);
	depth_test_min(group_pos + kOffsets3x3[8], min_depth, min_pos);

	// Velocity out
	velocity = imageLoad(velocity_buffer, ivec2(group_top_left + min_pos)).xy;
}
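
// Using the velocity of the closest (smallest) depth in the 3x3 neighbourhood is a common
// "velocity dilation" trick: pixels near a moving foreground silhouette reproject with the
// foreground's motion rather than the background's, which keeps edges stable under motion.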

/*------------------------------------------------------------------------------
HISTORY SAMPLING
------------------------------------------------------------------------------*/

vec3 sample_catmull_rom_9(sampler2D stex, vec2 uv, vec2 resolution) {
	// Source: https://gist.github.com/TheRealMJP/c83b8c0f46b63f3a88a5986f4fa982b1
	// License: https://gist.github.com/TheRealMJP/bc503b0b87b643d3505d41eab8b332ae

	// We're going to sample a 4x4 grid of texels surrounding the target UV coordinate. We'll do this by rounding
	// down the sample location to get the exact center of our "starting" texel. The starting texel will be at
	// location [1, 1] in the grid, where [0, 0] is the top left corner.
	vec2 sample_pos = uv * resolution;
	vec2 texPos1 = floor(sample_pos - 0.5f) + 0.5f;

	// Compute the fractional offset from our starting texel to our original sample location, which we'll
	// feed into the Catmull-Rom spline function to get our filter weights.
	vec2 f = sample_pos - texPos1;

	// Compute the Catmull-Rom weights using the fractional offset that we calculated earlier.
	// These equations are pre-expanded based on our knowledge of where the texels will be located,
	// which lets us avoid having to evaluate a piece-wise function.
	vec2 w0 = f * (-0.5f + f * (1.0f - 0.5f * f));
	vec2 w1 = 1.0f + f * f * (-2.5f + 1.5f * f);
	vec2 w2 = f * (0.5f + f * (2.0f - 1.5f * f));
	vec2 w3 = f * f * (-0.5f + 0.5f * f);

	// Work out weighting factors and sampling offsets that will let us use bilinear filtering to
	// simultaneously evaluate the middle 2 samples from the 4x4 grid.
	vec2 w12 = w1 + w2;
	vec2 offset12 = w2 / (w1 + w2);

	// Compute the final UV coordinates we'll use for sampling the texture.
	vec2 texPos0 = texPos1 - 1.0f;
	vec2 texPos3 = texPos1 + 2.0f;
	vec2 texPos12 = texPos1 + offset12;

	texPos0 /= resolution;
	texPos3 /= resolution;
	texPos12 /= resolution;

	vec3 result = vec3(0.0f, 0.0f, 0.0f);

	result += textureLod(stex, vec2(texPos0.x, texPos0.y), 0.0).xyz * w0.x * w0.y;
	result += textureLod(stex, vec2(texPos12.x, texPos0.y), 0.0).xyz * w12.x * w0.y;
	result += textureLod(stex, vec2(texPos3.x, texPos0.y), 0.0).xyz * w3.x * w0.y;

	result += textureLod(stex, vec2(texPos0.x, texPos12.y), 0.0).xyz * w0.x * w12.y;
	result += textureLod(stex, vec2(texPos12.x, texPos12.y), 0.0).xyz * w12.x * w12.y;
	result += textureLod(stex, vec2(texPos3.x, texPos12.y), 0.0).xyz * w3.x * w12.y;

	result += textureLod(stex, vec2(texPos0.x, texPos3.y), 0.0).xyz * w0.x * w3.y;
	result += textureLod(stex, vec2(texPos12.x, texPos3.y), 0.0).xyz * w12.x * w3.y;
	result += textureLod(stex, vec2(texPos3.x, texPos3.y), 0.0).xyz * w3.x * w3.y;

	return max(result, 0.0f);
}
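
// Why 9 taps instead of 16: a separable 4x4 Catmull-Rom filter would normally need 16 point
// fetches, but since w1 and w2 share the same sign over [0, 1] they can be merged into a single
// bilinear fetch at texPos1 + w2 / (w1 + w2) per axis, reducing 4 taps per axis to 3 and the
// total to 3x3 = 9 hardware-filtered samples. The w0..w3 polynomials above are the standard
// Catmull-Rom basis (tension a = -0.5) written in Horner form.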

/*------------------------------------------------------------------------------
HISTORY CLIPPING
------------------------------------------------------------------------------*/

// Based on "Temporal Reprojection Anti-Aliasing" - https://github.com/playdeadgames/temporal
vec3 clip_aabb(vec3 aabb_min, vec3 aabb_max, vec3 p, vec3 q) {
	vec3 r = q - p;
	vec3 rmax = (aabb_max - p.xyz);
	vec3 rmin = (aabb_min - p.xyz);

	if (r.x > rmax.x + FLT_MIN) {
		r *= (rmax.x / r.x);
	}
	if (r.y > rmax.y + FLT_MIN) {
		r *= (rmax.y / r.y);
	}
	if (r.z > rmax.z + FLT_MIN) {
		r *= (rmax.z / r.z);
	}

	if (r.x < rmin.x - FLT_MIN) {
		r *= (rmin.x / r.x);
	}
	if (r.y < rmin.y - FLT_MIN) {
		r *= (rmin.y / r.y);
	}
	if (r.z < rmin.z - FLT_MIN) {
		r *= (rmin.z / r.z);
	}

	return p + r;
}
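
// clip_aabb() shrinks the segment from p (a point inside the box) towards q (the history
// colour) until it fits inside [aabb_min, aabb_max], pulling the history onto the box surface
// along its original direction instead of clamping each channel independently. For example,
// with a [0, 1]^3 box, p = vec3(0.5) and q = vec3(2.0, 0.5, 0.5): r = (1.5, 0, 0) gets scaled
// by rmax.x / r.x = 0.5 / 1.5, so the result is (1.0, 0.5, 0.5) on the box boundary.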

// Clip history to the neighbourhood of the current sample
vec3 clip_history_3x3(uvec2 group_pos, vec3 color_history, vec2 velocity_closest) {
	// Sample a 3x3 neighbourhood
	vec3 s1 = load_color(group_pos + kOffsets3x3[0]);
	vec3 s2 = load_color(group_pos + kOffsets3x3[1]);
	vec3 s3 = load_color(group_pos + kOffsets3x3[2]);
	vec3 s4 = load_color(group_pos + kOffsets3x3[3]);
	vec3 s5 = load_color(group_pos + kOffsets3x3[4]);
	vec3 s6 = load_color(group_pos + kOffsets3x3[5]);
	vec3 s7 = load_color(group_pos + kOffsets3x3[6]);
	vec3 s8 = load_color(group_pos + kOffsets3x3[7]);
	vec3 s9 = load_color(group_pos + kOffsets3x3[8]);

	// Compute min and max (with an adaptive box size, which greatly reduces ghosting)
	vec3 color_avg = (s1 + s2 + s3 + s4 + s5 + s6 + s7 + s8 + s9) * RPC_9;
	vec3 color_avg2 = ((s1 * s1) + (s2 * s2) + (s3 * s3) + (s4 * s4) + (s5 * s5) + (s6 * s6) + (s7 * s7) + (s8 * s8) + (s9 * s9)) * RPC_9;
	float box_size = mix(0.0f, 2.5f, smoothstep(0.02f, 0.0f, length(velocity_closest)));
	vec3 dev = sqrt(abs(color_avg2 - (color_avg * color_avg))) * box_size;
	vec3 color_min = color_avg - dev;
	vec3 color_max = color_avg + dev;

	// Variance clipping
	vec3 color = clip_aabb(color_min, color_max, clamp(color_avg, color_min, color_max), color_history);

	// Clamp to prevent NaNs
	color = clamp(color, FLT_MIN, FLT_MAX);

	return color;
}
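
// Variance clipping, in the spirit of Marco Salvi's "An Excursion in Temporal Supersampling":
// the clipping box is the neighbourhood mean +/- box_size * sigma, where
// sigma = sqrt(E[c^2] - E[c]^2) is the per-channel standard deviation estimated from the
// 9 samples. box_size fades from 2.5 down to 0 as the closest-depth velocity grows past
// roughly 0.02 UV units, so the faster the motion, the tighter the box and the less stale
// history survives.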

/*------------------------------------------------------------------------------
TAA
------------------------------------------------------------------------------*/

const vec3 lumCoeff = vec3(0.299f, 0.587f, 0.114f);

float luminance(vec3 color) {
	return max(dot(color, lumCoeff), 0.0001f);
}
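
// lumCoeff are the Rec. 601 luma weights; the 0.0001 floor just keeps luminance() strictly
// positive for pure-black input (the flicker-reduction divide below is additionally guarded
// by max(..., 1.001)).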

// This is "velocity disocclusion" as described by https://www.elopezr.com/temporal-aa-and-the-quest-for-the-holy-trail/.
// We use texel space, so our scale and threshold differ.
float get_factor_disocclusion(vec2 uv_reprojected, vec2 velocity) {
	vec2 velocity_previous = imageLoad(last_velocity_buffer, ivec2(uv_reprojected * params.resolution)).xy;

	vec2 velocity_texels = velocity * params.resolution;
	vec2 prev_velocity_texels = velocity_previous * params.resolution;

	float disocclusion = length(prev_velocity_texels - velocity_texels) - params.disocclusion_threshold;

	return clamp(disocclusion * params.disocclusion_scale, 0.0, 1.0);
}
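
// The idea: if the velocity stored at the reprojected location last frame disagrees with the
// current pixel's velocity by more than disocclusion_threshold (both measured in texels), the
// pixel was most likely occluded or newly revealed, so the returned factor (scaled by
// disocclusion_scale and clamped to [0, 1]) is added to the blend factor to favour the current
// frame over the stale history.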

vec3 temporal_antialiasing(uvec2 pos_group_top_left, uvec2 pos_group, uvec2 pos_screen, vec2 uv, sampler2D tex_history) {
	// Get the velocity of the current pixel
	vec2 velocity = imageLoad(velocity_buffer, ivec2(pos_screen)).xy;

	// Get reprojected uv
	vec2 uv_reprojected = uv + velocity;

	// Get input color
	vec3 color_input = load_color(pos_group);

	// Get history color (catmull-rom reduces a lot of the blurring that you get under motion)
	vec3 color_history = sample_catmull_rom_9(tex_history, uv_reprojected, params.resolution).rgb;

	// Clip history to the neighbourhood of the current sample (fixes a lot of the ghosting).
	vec2 velocity_closest = vec2(0.0); // This is best done by using the velocity with the closest depth.
	get_closest_pixel_velocity_3x3(pos_group, pos_group_top_left, velocity_closest);
	color_history = clip_history_3x3(pos_group, color_history, velocity_closest);

	// Compute blend factor
	float blend_factor = RPC_16; // We want to be able to accumulate as many jitter samples as we generated, that is, 16.
	{
		// If re-projected UV is out of screen, converge to current color immediately.
		float factor_screen = any(lessThan(uv_reprojected, vec2(0.0))) || any(greaterThan(uv_reprojected, vec2(1.0))) ? 1.0 : 0.0;

		// Increase blend factor when there is disocclusion (fixes a lot of the remaining ghosting).
		float factor_disocclusion = get_factor_disocclusion(uv_reprojected, velocity);

		// Add to the blend factor
		blend_factor = clamp(blend_factor + factor_screen + factor_disocclusion, 0.0, 1.0);
	}

	// Resolve
	vec3 color_resolved = vec3(0.0);
	{
		// Tonemap
		color_history = reinhard(color_history);
		color_input = reinhard(color_input);

		// Reduce flickering
		float lum_color = luminance(color_input);
		float lum_history = luminance(color_history);
		float diff = abs(lum_color - lum_history) / max(lum_color, max(lum_history, 1.001));
		diff = 1.0 - diff;
		diff = diff * diff;
		blend_factor = mix(0.0, blend_factor, diff);

		// Lerp/blend
		color_resolved = mix(color_history, color_input, blend_factor);

		// Inverse tonemap
		color_resolved = reinhard_inverse(color_resolved);
	}

	return color_resolved;
}
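
// Flicker reduction, step by step: diff starts as the luminance difference normalised by the
// brighter of the two values (never less than 1.001, so no division by zero), is inverted and
// squared, and then scales blend_factor. A large luminance mismatch therefore drives
// blend_factor towards 0, i.e. towards the accumulated history, which damps the frame-to-frame
// shimmer that high-contrast jittered edges would otherwise produce.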

void main() {
#ifdef USE_SUBGROUPS
	populate_group_shared_memory(gl_WorkGroupID.xy, gl_LocalInvocationIndex);
#endif

	// Out of bounds check
	if (any(greaterThanEqual(vec2(gl_GlobalInvocationID.xy), params.resolution))) {
		return;
	}

#ifdef USE_SUBGROUPS
	const uvec2 pos_group = gl_LocalInvocationID.xy;
	const uvec2 pos_group_top_left = gl_WorkGroupID.xy * kGroupSize - kBorderSize;
#else
	const uvec2 pos_group = gl_GlobalInvocationID.xy;
	const uvec2 pos_group_top_left = uvec2(0, 0);
#endif

	const uvec2 pos_screen = gl_GlobalInvocationID.xy;
	const vec2 uv = (gl_GlobalInvocationID.xy + 0.5f) / params.resolution;

	vec3 result = temporal_antialiasing(pos_group_top_left, pos_group, pos_screen, uv, history_buffer);
	imageStore(output_buffer, ivec2(gl_GlobalInvocationID.xy), vec4(result, 1.0));
}