// mesh-shader-microsoft-sample.azsl

//*********************************************************
//
// Copyright (c) Microsoft. All rights reserved.
// This code is licensed under the MIT License (MIT).
// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF
// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY
// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR
// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT.
//
//*********************************************************
  11. struct Instance
  12. {
  13. float4x4 World;
  14. float4x4 WorldInvTrans;
  15. float Scale;
  16. uint Flags;
  17. };
  18. struct Constants
  19. {
  20. float4x4 View;
  21. float4x4 ViewProj;
  22. float4 Planes[6];
  23. float3 ViewPosition;
  24. uint HighlightedIndex;
  25. float3 CullViewPosition;
  26. uint SelectedIndex;
  27. uint DrawMeshlets;
  28. };
  29. struct MeshInfo
  30. {
  31. uint IndexSize;
  32. uint MeshletCount;
  33. uint LastMeshletVertCount;
  34. uint LastMeshletPrimCount;
  35. };
  36. struct Meshlet
  37. {
  38. uint VertCount;
  39. uint VertOffset;
  40. uint PrimCount;
  41. uint PrimOffset;
  42. };
  43. struct CullData
  44. {
  45. float4 BoundingSphere;
  46. uint NormalCone;
  47. float ApexOffset;
  48. };
  49. bool IsConeDegenerate(CullData c)
  50. {
  51. return (c.NormalCone >> 24) == 0xff;
  52. }
  53. float4 UnpackCone(uint packed)
  54. {
  55. float4 v;
  56. v.x = float((packed >> 0) & 0xFF);
  57. v.y = float((packed >> 8) & 0xFF);
  58. v.z = float((packed >> 16) & 0xFF);
  59. v.w = float((packed >> 24) & 0xFF);
  60. v = v / 255.0;
  61. v.xyz = v.xyz * 2.0 - 1.0;
  62. return v;
  63. }
  64. struct Vertex
  65. {
  66. float3 Position;
  67. float3 Normal;
  68. };
  69. struct VertexOut
  70. {
  71. float4 PositionHS : SV_Position;
  72. float3 PositionVS : POSITION0;
  73. float3 Normal : NORMAL0;
  74. uint MeshletIndex : COLOR0;
  75. };
  76. struct Payload
  77. {
  78. uint MeshletIndices[ 32 ];
  79. };
  80. ShaderResourceGroupSemantic slot1
  81. {
  82. FrequencyId = 1;
  83. };
  84. ShaderResourceGroup u_ : slot1
  85. {
  86. ConstantBuffer<Constants> GConstants;
  87. ConstantBuffer<MeshInfo> GMeshInfo;
  88. ConstantBuffer<Instance> GInstance;
  89. StructuredBuffer<Vertex> Vertices;
  90. StructuredBuffer<Meshlet> Meshlets;
  91. ByteAddressBuffer UniqueVertexIndices;
  92. StructuredBuffer<uint> PrimitiveIndices;
  93. StructuredBuffer<CullData> MeshletCullData;
  94. }
  95. float3 RotateVector(float3 v0, float3 axis, float angle)
  96. {
  97. float cs = cos(angle);
  98. return cs * v0 + sin(angle) * cross(axis, v0) + (1 - cs) * dot(axis, v0) * axis;
  99. }
  100. uint3 UnpackPrimitive(uint primitive) { return uint3(primitive & 0x3FF, (primitive >> 10) & 0x3FF, (primitive >> 20) & 0x3FF); }
  101. uint GetVertexIndex(Meshlet m, uint localIndex)
  102. {
  103. localIndex = m.VertOffset + localIndex;
  104. if (u_::GMeshInfo.IndexSize == 4)
  105. {
  106. return u_::UniqueVertexIndices.Load(localIndex * 4);
  107. }
  108. else
  109. {
  110. uint wordOffset = (localIndex & 0x1);
  111. uint byteOffset = (localIndex / 2) * 4;
  112. uint indexPair = u_::UniqueVertexIndices.Load(byteOffset);
  113. uint index = (indexPair >> (wordOffset * 16)) & 0xffff;
  114. return index;
  115. }
  116. }
  117. uint3 GetPrimitive(Meshlet m, uint index)
  118. {
  119. return UnpackPrimitive(u_::PrimitiveIndices[m.PrimOffset + index]);
  120. }
  121. VertexOut GetVertexAttributes(uint meshletIndex, uint vertexIndex)
  122. {
  123. Vertex v = u_::Vertices[vertexIndex];
  124. float4 positionWS = mul(float4(v.Position, 1), u_::GInstance.World);
  125. VertexOut vout;
  126. vout.PositionVS = mul(positionWS, u_::GConstants.View).xyz;
  127. vout.PositionHS = mul(positionWS, u_::GConstants.ViewProj);
  128. vout.Normal = mul(float4(v.Normal, 0), u_::GInstance.WorldInvTrans).xyz;
  129. vout.MeshletIndex = meshletIndex;
  130. return vout;
  131. }
  132. [RootSignature( "CBV(b0), CBV(b1), CBV(b2), SRV(t0), SRV(t1), SRV(t2), SRV(t3), SRV(t4)" )]
  133. [NumThreads(128, 1, 1)]
  134. [OutputTopology("triangle")]
  135. void main(
  136. uint dtid : SV_DispatchThreadID,
  137. uint gtid : SV_GroupThreadID,
  138. uint gid : SV_GroupID,
  139. in payload Payload a_payload,
  140. out vertices VertexOut verts[64],
  141. out indices uint3 tris[126]
  142. )
  143. {
  144. uint meshletIndex = a_payload.MeshletIndices[gid];
  145. if (meshletIndex >= u_::GMeshInfo.MeshletCount)
  146. return;
  147. Meshlet m = u_::Meshlets[meshletIndex];
  148. SetMeshOutputCounts(m.VertCount, m.PrimCount);
  149. if (gtid < m.VertCount)
  150. {
  151. uint vertexIndex = GetVertexIndex(m, gtid);
  152. verts[gtid] = GetVertexAttributes(meshletIndex, vertexIndex);
  153. }
  154. if (gtid < m.PrimCount)
  155. {
  156. tris[gtid] = GetPrimitive(m, gtid);
  157. }
  158. }