DepthUpsample.azsl 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/Features/SrgSemantics.azsli>
  9. // --- Algorithm Overview ---
  10. //
  11. // This shader will upsample an input image using two input depth textures.
  12. // For simplicity, we call the source image 'sourceHalfRes', the low resolution depth 'depthHalfRes',
  13. // and the higher resolution depth 'depthFullRes' (which has the same resolution as the output image).
  14. // In order to reduce texture operations, each thread writes 2x2 pixels in the target output image.
  15. // This allows each thread to re-use results from texture gather operations between output pixels.
  16. //
  17. // To illustrate, consider the following texture (each number denotes a pixel)
  18. //
  19. // 00 01 02 03 04 05 06 07 08 09
  20. //
  21. // 10 11 12 13 14 15 16 17 18 19
  22. //
  23. // 20 21 22 23 24 25 26 27 28 29
  24. //
  25. // 30 31 32 33 34 35 36 37 38 39
  26. //
  27. // The downsampled version of this texture would have pixels at the following H* locations:
  28. //
  29. //  00  01  02  03  04  05  06  07  08  09
  30. //    H0      H1      H2      H3      H4
  31. //  10  11  12  13  14  15  16  17  18  19
  32. //
  33. //  20  21  22  23  24  25  26  27  28  29
  34. //    H5      H6      H7      H8      H9
  35. //  30  31  32  33  34  35  36  37  38  39
  36. //
  37. // To calculate the upsampled output pixel (11), we need four half-res depth values (H0, H1, H5, H6),
  38. // four half-res source values (H0, H1, H5, H6) and a full-res depth value (11).
  39. // Note that these same half-res depth and source values are also used to calculate output pixels (12, 21, 22)
  40. // Also note that pixels (H0, H1, H5, H6) can be fetched with a single gather, as can (11, 12, 21, 22)
  41. //
  42. // Thus, we can use a single thread to calculate and output upsampled pixels (11, 12, 21, 22)
  43. // For this, the thread would only need to perform three gathers (assuming source is a single channel texture)
  44. // Gather 1: half-res depth (H0, H1, H5, H6)
  45. // Gather 2: half-res source (H0, H1, H5, H6)
  46. // Gather 3: full-res depth (11, 12, 21, 22)
  47. //
  48. // Thus, we dispatch threads at the following T* locations
  49. //
  50. //   T-00    T-01    T-02    T-03    T-04    T-05
  51. //      00  01  02  03  04  05  06  07  08  09
  52. //        H0      H1      H2      H3      H4
  53. //      10  11  12  13  14  15  16  17  18  19
  54. //   T-10    T-11    T-12    T-13    T-14    T-15
  55. //      20  21  22  23  24  25  26  27  28  29
  56. //        H5      H6      H7      H8      H9
  57. //      30  31  32  33  34  35  36  37  38  39
  58. //   T-20    T-21    T-22    T-23    T-24    T-25
  59. //
  60. // Continuing our example, here the thread T-11 would calculate the full res output pixels (11, 12, 21, 22)
  61. // Two things to note about the thread dispatch:
  62. //
  63. // 1) the width and height of the thread dispatch are equal to the width and height of the half-res textures + 1.
  64. // The +1 is to have enough threads to output to row 3* and column *9.
  65. // Note that if the full-res texture has odd dimensions this +1 is not necessary:
  66. //
  67. //   T-00    T-01    T-02    T-03    T-04
  68. //      00  01  02  03  04  05  06  07  08
  69. //        H0      H1      H2      H3      H4
  70. //      10  11  12  13  14  15  16  17  18
  71. //   T-10    T-11    T-12    T-13    T-14
  72. //      20  21  22  23  24  25  26  27  28
  73. //        H5      H6      H7      H8      H9
  74. //
  75. // 2) While the thread dispatch has similar dimensions to the half-res textures, the positions are shifted by (-1, -1),
  76. // i.e. they are shifted up and to the left by the width of a full-res pixel. This is so the threads can properly use
  77. // texture gather instructions on both the downsampled depth/source and the full-res depth (see above example for T-11)
  78. //
  79. #define THREADS 16 // Edge length of the thread group: MainCS runs THREADS x THREADS x 1 threads per group
  80. ShaderResourceGroup PassSrg : SRG_PerPass // Per-pass textures, constants and sampler used by MainCS
  81. {
  82. Texture2D<float> m_depthFullRes; // Depth at output resolution; MainCS gathers one 2x2 footprint per thread
  83. Texture2D<float> m_depthHalfRes; // Low resolution depth; gathered with the same UV as m_sourceHalfRes
  84. Texture2D<float> m_sourceHalfRes; // Single-channel low resolution image that is being upsampled
  85. RWTexture2D<float> m_outputFullRes; // Upsampled result; each MainCS thread writes a 2x2 block of it
  86. // Must match the struct in DepthDownsamplePasses.cpp
  87. struct UpsampleConstants
  88. {
  89. // The size of a pixel in the input (half-res) image relative to screenspace UV
  90. // Calculated by taking the inverse of the texture dimensions
  91. float2 m_inputPixelSize; // MainCS multiplies the thread position by this to get the half-res gather UV
  92. // The size of a pixel in the output (full-res) image relative to screenspace UV
  93. // Calculated by taking the inverse of the texture dimensions
  94. float2 m_outputPixelSize; // MainCS multiplies twice the thread position by this to get the full-res gather UV
  95. };
  96. UpsampleConstants m_constants;
  97. Sampler PointSampler // Point-filtered, clamped sampler passed to every Gather call in MainCS
  98. {
  99. MinFilter = Point;
  100. MagFilter = Point;
  101. MipFilter = Point;
  102. AddressU = Clamp;
  103. AddressV = Clamp;
  104. AddressW = Clamp;
  105. };
  106. }
  // Weights a sample by how closely two depth values agree.
  // Returns the inverse square of the absolute depth difference, so a pair of similar depths
  // yields a far larger factor than a pair that straddles a depth discontinuity.
  // Only |depth1 - depth2| is used, so the two arguments are interchangeable.
  float GetDepthFactor(float depth1, float depth2)
  {
      // Small bias so that identical depths do not divide by zero
      const float bias = 0.00001f;
      const float separation = abs(depth1 - depth2) + bias;
      return 1.0f / (separation * separation);
  }
  // Depth-weighted average of the four gathered half-res source values for one output pixel.
  //   fullResDepth   - full-res depth of the output pixel being computed
  //   halfResDepths  - gathered half-res depths, swizzled by the caller so that .x belongs to the
  //                    half-res texel nearest to the output pixel, .y to its horizontal neighbour,
  //                    .z to its vertical neighbour and .w to the diagonal one
  //   halfResSources - gathered half-res source values, swizzled the same way as halfResDepths
  float BlendHalfResSamples(float fullResDepth, float4 halfResDepths, float4 halfResSources)
  {
      // 0.75 and 0.25 express how far the full-res pixel is from the half-res pixels being sampled.
      // Consider two half-res pixels with two full-res pixels in between. The full-res pixel on the
      // right is 3x closer to the half-res pixel on the right than to the one on the left, so the
      // per-axis weights are 3/4 and 1/4. Each entry is the product of the two per-axis weights.
      const float4 proximity = float4(0.75f * 0.75f, 0.25f * 0.75f, 0.75f * 0.25f, 0.25f * 0.25f);

      float accumulated = 0.0f;
      float weightSum = 0.0f;
      [unroll]
      for (uint i = 0; i < 4; ++i)
      {
          // Scale the positional weight by how well the half-res depth matches the full-res depth
          const float sampleWeight = proximity[i] * GetDepthFactor(fullResDepth, halfResDepths[i]);
          accumulated += halfResSources[i] * sampleWeight;
          weightSum += sampleWeight;
      }
      return accumulated / weightSum;
  }

  // Compute entry point. Each thread performs three gathers (half-res depth, half-res source,
  // full-res depth) and writes a 2x2 block of depth-aware upsampled values to m_outputFullRes.
  [numthreads(THREADS, THREADS, 1)]
  void MainCS(uint3 dispatch_id: SV_DispatchThreadID)
  {
      const float2 threadCoord = dispatch_id.xy;

      // Half-res depth and source share one gather UV
      const float2 halfResUV = threadCoord * PassSrg::m_constants.m_inputPixelSize;
      const float4 halfResDepths = PassSrg::m_depthHalfRes.Gather(PassSrg::PointSampler, halfResUV);
      const float4 halfResSources = PassSrg::m_sourceHalfRes.Gather(PassSrg::PointSampler, halfResUV);

      // Full-res depths of the four pixels this thread outputs
      const float2 fullResUV = threadCoord * 2.0f * PassSrg::m_constants.m_outputPixelSize;
      const float4 fullResDepths = PassSrg::m_depthFullRes.Gather(PassSrg::PointSampler, fullResUV);

      // Gather operation retrieves values with the following layout:
      //
      //   W Z
      //   X Y
      //
      // Each swizzle below lists the gathered components in the order
      // nearest, horizontal neighbour, vertical neighbour, diagonal for that output pixel.
      const float upperLeft = BlendHalfResSamples(fullResDepths.w, halfResDepths.wzxy, halfResSources.wzxy);
      const float upperRight = BlendHalfResSamples(fullResDepths.z, halfResDepths.zwyx, halfResSources.zwyx);
      const float lowerRight = BlendHalfResSamples(fullResDepths.y, halfResDepths.yxzw, halfResSources.yxzw);
      const float lowerLeft = BlendHalfResSamples(fullResDepths.x, halfResDepths.xywz, halfResSources.xywz);

      // To understand the -1, read the last paragraph of the Algorithm Overview section at the start
      // NOTE(review): for dispatch ids with a zero component the -1 wraps the unsigned coordinate,
      // so some of the writes below appear to rely on out-of-range stores being dropped - confirm
      // this holds on every target backend.
      const uint2 upperLeftPixel = mad(dispatch_id.xy, 2, -1);
      PassSrg::m_outputFullRes[upperLeftPixel] = upperLeft;
      PassSrg::m_outputFullRes[upperLeftPixel + uint2(1, 0)] = upperRight;
      PassSrg::m_outputFullRes[upperLeftPixel + uint2(1, 1)] = lowerRight;
      PassSrg::m_outputFullRes[upperLeftPixel + uint2(0, 1)] = lowerLeft;
  }