SHDemo.azsl 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <viewsrg.srgi>
  9. #include <Atom/Features/SphericalHarmonicsUtility.azsli>
// Per-object constants that control which spherical harmonics (SH) basis function is visualised
// and how it is drawn.
ShaderResourceGroup SphericalHarmonicsInstanceSrg : SRG_PerObject
{
    // SH band passed to the SHBasis* helpers (presumably the band/degree index l — confirm
    // against SphericalHarmonicsUtility.azsli).
    int m_shBand;
    // SH order passed to the SHBasis* helpers (presumably the index m within the band — confirm).
    int m_shOrder;
    // Selects the SH evaluator used by EvalMarchResultWS:
    // 0 = SHBasisPoly3, 1 = SHBasisNaive16, 2 = SHBasisNaiveEx; any other value yields r = 0.
    int m_shSolver;
    // true: the surface radius follows |SH value| in each direction;
    // false: a fixed sphere of radius 0.35 around the origin is drawn, colored by the SH value.
    bool m_enableDistortion;
    // Transform applied to vertex positions in MainVS as mul(rowVector, m_objectMatrix).
    // NOTE(review): combined with column_major storage, confirm this matches how the C++ side
    // fills the matrix.
    column_major float4x4 m_objectMatrix;
}
// Vertex input consumed by MainVS.
struct VSInput
{
    // Vertex position; transformed by SphericalHarmonicsInstanceSrg::m_objectMatrix in MainVS.
    float3 m_position : POSITION;
    // Texture coordinate; forwarded unchanged and used by MainPS to build the view ray.
    float2 m_uv : UV0;
};
// Values passed from MainVS to MainPS.
struct VSOutput
{
    // Clip-space position produced by MainVS.
    float4 m_position : SV_Position;
    // Texture coordinate copied from the vertex; MainPS subtracts 0.5 from it to recenter it
    // before turning it into a view-space ray direction.
    float2 m_uv : UV0;
};
  28. VSOutput MainVS(VSInput vsInput)
  29. {
  30. VSOutput OUT;
  31. OUT.m_position = mul(float4(vsInput.m_position, 1.0), SphericalHarmonicsInstanceSrg::m_objectMatrix);
  32. OUT.m_uv = vsInput.m_uv;
  33. return OUT;
  34. }
// Pixel shader output: a single color render target.
struct PSOutput
{
    // Final color; MainPS always writes alpha = 1.0.
    float4 m_color : SV_Target0;
};
  39. // compare against the length of marched point (i.e. distance from world origin) against the magnitude spherical harmonics
  40. // evaluation centered around the world origin (0, 0, 0), if "m_enableDistortion" is set to false it will compare with
  41. // a unit sphere at the origin with radius = 0.35
  42. float3 EvalMarchResultWS(float3 marchedPoint)
  43. {
  44. // distance between reached point (not necessarily hit) in this step and world origin
  45. float d = length(marchedPoint);
  46. // closest surface point on target unit sphere (could change each step)
  47. // used as sample to evaluate SH basis
  48. float3 samplePoint = marchedPoint / d;
  49. // ideally radius of SH basis at each point
  50. float r = 0.0;
  51. switch(int(SphericalHarmonicsInstanceSrg::m_shSolver))
  52. {
  53. case 0: r = SHBasisPoly3(SphericalHarmonicsInstanceSrg::m_shBand, SphericalHarmonicsInstanceSrg::m_shOrder, samplePoint); break;
  54. case 1: r = SHBasisNaive16(SphericalHarmonicsInstanceSrg::m_shBand, SphericalHarmonicsInstanceSrg::m_shOrder, samplePoint); break;
  55. case 2: r = SHBasisNaiveEx(SphericalHarmonicsInstanceSrg::m_shBand, SphericalHarmonicsInstanceSrg::m_shOrder, samplePoint); break;
  56. }
  57. float3 result = float3(0.0, 0.0, 0.0);
  58. if(SphericalHarmonicsInstanceSrg::m_enableDistortion)
  59. {
  60. result = float3(d - abs(r), sign(r), d);
  61. }
  62. else
  63. {
  64. // second element generate weight for color interpolation later based on the value of r
  65. // constants only used to tune the color, don't have special meaning
  66. result = float3(d - 0.35, -1.0 + 2.0*clamp(0.5 + 16.0*r,0.0,1.0), d);
  67. }
  68. // output contains:
  69. // x: distance away from closest surface point on target object
  70. // (distorted by SH value if controlValues.w is 0, otherwise target is a sphere at origin with radius = 0.35)
  71. // y: normalized sign of SH value for color interpolation
  72. // z: distance between reached point (not necessarily hit) in this step and world origin
  73. return float3(result.x, 0.5 + 0.5*result.y, result.z);
  74. }
  75. // This function marches a ray spawn from "rayOrigin" along the given direction "rayDir" and check if it
  76. // hits the shape of SH basis with given band and order which centered at world origin (0, 0, 0) at each step
  77. float3 RayMarchIntersectWS(float3 rayOrigin, float3 rayDir)
  78. {
  79. float3 result = float3(1e10, -1.0, 1.0);
  80. float maxTraceDepth = 10.0;
  81. float nextStep = 1.0;
  82. // length of ray
  83. float t = 0.0;
  84. // ray payload where:
  85. // x holds sign of result
  86. // y holds distance between hit point and axis origin (not used in this shader)
  87. float2 payload = float2(-1.0, -1.0);
  88. for(int i = 0; i < 600; i++)
  89. {
  90. // stop either next step is too small (hit) or ray length exceed depth limit (miss)
  91. if(nextStep < 0.001 || t > maxTraceDepth) break;
  92. float3 res = EvalMarchResultWS(rayOrigin + rayDir*t);
  93. nextStep = res.x;
  94. payload = res.yz;
  95. t += nextStep * 0.1;
  96. }
  97. if(t < maxTraceDepth && t < result.x)
  98. result = float3(t, payload.x, payload.y);
  99. return result;
  100. }
  101. // evaluate normal by computing first derivative on the edge by definition
  102. float3 EvalNormalWS(float3 pos)
  103. {
  104. float3 epsilon = float3(0.001, 0.0, 0.0);
  105. return normalize( float3(
  106. EvalMarchResultWS(pos+epsilon.xyy).x - EvalMarchResultWS(pos-epsilon.xyy).x,
  107. EvalMarchResultWS(pos+epsilon.yxy).x - EvalMarchResultWS(pos-epsilon.yxy).x,
  108. EvalMarchResultWS(pos+epsilon.yyx).x - EvalMarchResultWS(pos-epsilon.yyx).x
  109. ) );
  110. }
  111. // Ray marcher for SH visualisation, based on https://www.shadertoy.com/view/lsfXWH
  112. PSOutput MainPS(VSOutput psInput)
  113. {
  114. // all following calculations and coordinates are in world space
  115. PSOutput OUT;
  116. float3 hdrColor = float3(0.5, 0.5, 0.5);
  117. // inverse of rotation matrix is its transpose due to orthogonality
  118. // position of translation part won't affect result since it's not used here
  119. float4x4 invView = transpose(ViewSrg::m_viewMatrix);
  120. float2 recenteredUV = psInput.m_uv - float2(0.5, 0.5);
  121. // -1 because -z forward frame is used
  122. float3 viewSpaceRayDir = float3(recenteredUV.x, recenteredUV.y, -1.0);
  123. // world space ray origin & direction
  124. float3 rayDir = normalize(mul(invView, float4(viewSpaceRayDir, 0.0)).xyz);
  125. float3 rayOrigin = ViewSrg::m_worldPosition.xyz;
  126. // hit record include:
  127. // x: length of ray at intersection point, -1 if miss
  128. // y: sign of SH basis value, normalized to {0(negative), 1(positive)}
  129. // z: maginitude of SH basis value, larger value result in brighter color in visulisation
  130. float3 hitRecord = RayMarchIntersectWS(rayOrigin, rayDir);
  131. // only do shading if hit anything
  132. if(hitRecord.y > -0.5)
  133. {
  134. float3 pos = rayOrigin + rayDir * hitRecord.x;
  135. float3 normal = EvalNormalWS(pos);
  136. // interpolate two colors based on the sign of SH value
  137. float3 mat = 0.5*lerp( float3(0.2,0.4,0.5), float3(0.6,0.3,0.2), hitRecord.y );
  138. // tinted depth for simple visualisation, constants only used to tune the color, don't have special meaning
  139. hdrColor = mat * hitRecord.z * 5.0 + 0.2;
  140. }
  141. OUT.m_color = float4(hdrColor, 1.0);
  142. return OUT;
  143. }