Forráskód Böngészése

Merge pull request #2336 from tex3d/merge-dxil-1-5

Merge Dxil 1.5 features: MeshShader, SamplerFeedback, RayQuery (DXR 1.1)
Tex Riddell 6 éve
szülő
commit
5268c51c06
100 módosított fájl, 7262 hozzáadás és 1064 törlés
  1. 307 237
      docs/DXIL.rst
  2. 124 18
      docs/SPIR-V.rst
  3. 1 1
      external/SPIRV-Headers
  4. 1 1
      external/SPIRV-Tools
  5. 161 14
      include/dxc/DXIL/DxilConstants.h
  6. 17 1
      include/dxc/DXIL/DxilFunctionProps.h
  7. 1365 1
      include/dxc/DXIL/DxilInstructions.h
  8. 36 0
      include/dxc/DXIL/DxilMetadataHelper.h
  9. 12 2
      include/dxc/DXIL/DxilModule.h
  10. 1 0
      include/dxc/DXIL/DxilOperations.h
  11. 10 1
      include/dxc/DXIL/DxilShaderFlags.h
  12. 3 1
      include/dxc/DXIL/DxilShaderModel.h
  13. 1 1
      include/dxc/DXIL/DxilSigPoint.h
  14. 78 52
      include/dxc/DXIL/DxilSigPoint.inl
  15. 4 2
      include/dxc/DXIL/DxilSignature.h
  16. 1 1
      include/dxc/DXIL/DxilSignatureElement.h
  17. 29 1
      include/dxc/DXIL/DxilTypeSystem.h
  18. 1 0
      include/dxc/DXIL/DxilUtil.h
  19. 91 74
      include/dxc/DxilContainer/DxilPipelineStateValidation.h
  20. 3 3
      include/dxc/HLSL/ComputeViewIdState.h
  21. 17 0
      include/dxc/HLSL/DxilValidation.h
  22. 17 0
      include/dxc/HLSL/HLOperations.h
  23. 4 4
      include/dxc/HLSL/ViewIDPipelineValidation.inl
  24. 45 0
      include/dxc/HlslIntrinsicOp.h
  25. 1 1
      include/dxc/Support/HLSLOptions.td
  26. 1 0
      include/dxc/Support/SPIRVOptions.h
  27. 4 1
      include/dxc/dxcapi.internal.h
  28. 201 7
      lib/DXIL/DxilMetadataHelper.cpp
  29. 74 4
      lib/DXIL/DxilModule.cpp
  30. 217 6
      lib/DXIL/DxilOperations.cpp
  31. 16 0
      lib/DXIL/DxilResource.cpp
  32. 15 7
      lib/DXIL/DxilResourceBase.cpp
  33. 1 0
      lib/DXIL/DxilSemantic.cpp
  34. 14 0
      lib/DXIL/DxilShaderFlags.cpp
  35. 8 3
      lib/DXIL/DxilShaderModel.cpp
  36. 12 1
      lib/DXIL/DxilSignature.cpp
  37. 2 2
      lib/DXIL/DxilSignatureElement.cpp
  38. 56 1
      lib/DXIL/DxilTypeSystem.cpp
  39. 30 8
      lib/DXIL/DxilUtil.cpp
  40. 71 25
      lib/DxilContainer/DxilContainerAssembler.cpp
  41. 56 63
      lib/DxilPIXPasses/DxilShaderAccessTracking.cpp
  42. 46 33
      lib/HLSL/ComputeViewIdState.cpp
  43. 48 17
      lib/HLSL/ComputeViewIdStateBuilder.cpp
  44. 11 18
      lib/HLSL/DxilContainerReflection.cpp
  45. 17 6
      lib/HLSL/DxilPreserveAllOutputs.cpp
  46. 467 43
      lib/HLSL/DxilValidation.cpp
  47. 2 0
      lib/HLSL/HLModule.cpp
  48. 348 2
      lib/HLSL/HLOperationLower.cpp
  49. 208 62
      lib/HLSL/HLSignatureLower.cpp
  50. 9 1
      lib/HLSL/HLSignatureLower.h
  51. 25 8
      lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
  52. 7 0
      tools/clang/include/clang/AST/HlslTypes.h
  53. 24 0
      tools/clang/include/clang/Basic/Attr.td
  54. 10 4
      tools/clang/include/clang/Basic/DiagnosticSemaKinds.td
  55. 4 0
      tools/clang/include/clang/Basic/TokenKinds.def
  56. 1 0
      tools/clang/include/clang/SPIRV/FeatureManager.h
  57. 7 0
      tools/clang/include/clang/SPIRV/SpirvBuilder.h
  58. 4 0
      tools/clang/include/clang/SPIRV/SpirvContext.h
  59. 100 0
      tools/clang/lib/AST/ASTContextHLSL.cpp
  60. 5 0
      tools/clang/lib/AST/HlslTypes.cpp
  61. 276 26
      tools/clang/lib/CodeGen/CGHLSLMS.cpp
  62. 21 0
      tools/clang/lib/Parse/ParseDecl.cpp
  63. 10 0
      tools/clang/lib/Parse/ParseExpr.cpp
  64. 5 1
      tools/clang/lib/Parse/ParseStmt.cpp
  65. 4 0
      tools/clang/lib/Parse/ParseTentative.cpp
  66. 14 6
      tools/clang/lib/SPIRV/CapabilityVisitor.cpp
  67. 243 49
      tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
  68. 20 5
      tools/clang/lib/SPIRV/DeclResultIdMapper.h
  69. 3 0
      tools/clang/lib/SPIRV/FeatureManager.cpp
  70. 70 41
      tools/clang/lib/SPIRV/GlPerVertex.cpp
  71. 5 4
      tools/clang/lib/SPIRV/GlPerVertex.h
  72. 17 0
      tools/clang/lib/SPIRV/SpirvBuilder.cpp
  73. 532 43
      tools/clang/lib/SPIRV/SpirvEmitter.cpp
  74. 33 2
      tools/clang/lib/SPIRV/SpirvEmitter.h
  75. 32 0
      tools/clang/lib/Sema/SemaExpr.cpp
  76. 3 1
      tools/clang/lib/Sema/SemaExprCXX.cpp
  77. 245 15
      tools/clang/lib/Sema/SemaHLSL.cpp
  78. 153 132
      tools/clang/lib/Sema/gen_intrin_main_tables_15.h
  79. 66 0
      tools/clang/test/CodeGenHLSL/batch/declarations/resources/textures/feedback.hlsl
  80. 1 1
      tools/clang/test/CodeGenHLSL/batch/expressions/intrinsics/misc/abs1.hlsl
  81. 19 0
      tools/clang/test/CodeGenHLSL/batch/shader_stages/raytracing/rayquery/tracerayinline.hlsl
  82. 205 0
      tools/clang/test/CodeGenHLSL/batch/shader_stages/raytracing/rayquery/tryAllOps.hlsl
  83. 39 0
      tools/clang/test/CodeGenHLSL/batch/shader_stages/raytracing/raytracing_anyhit_geometryIndex.hlsl
  84. 30 0
      tools/clang/test/CodeGenHLSL/batch/shader_stages/raytracing/raytracing_closesthit_geometryIndex.hlsl
  85. 21 0
      tools/clang/test/CodeGenHLSL/batch/shader_stages/raytracing/raytracing_intersection_geometryIndex.hlsl
  86. 22 0
      tools/clang/test/CodeGenHLSL/mesh-val/amplification.hlsl
  87. 22 0
      tools/clang/test/CodeGenHLSL/mesh-val/asOversizePayload.hlsl
  88. 78 0
      tools/clang/test/CodeGenHLSL/mesh-val/mesh.hlsl
  89. 21 0
      tools/clang/test/CodeGenHLSL/mesh-val/missingDispatchMesh.hlsl
  90. 62 0
      tools/clang/test/CodeGenHLSL/mesh-val/missingSetMeshOutputCounts.hlsl
  91. 63 0
      tools/clang/test/CodeGenHLSL/mesh-val/msOversizePayload.hlsl
  92. 23 0
      tools/clang/test/CodeGenHLSL/mesh-val/multipleDispatchMesh.hlsl
  93. 64 0
      tools/clang/test/CodeGenHLSL/mesh-val/multipleSetMeshOutputCounts.hlsl
  94. 24 0
      tools/clang/test/CodeGenHLSL/mesh-val/nonDominatingDispatchMesh.hlsl
  95. 63 0
      tools/clang/test/CodeGenHLSL/mesh-val/nonDominatingSetMeshOutputCounts.hlsl
  96. 74 0
      tools/clang/test/CodeGenHLSL/mesh-val/oversizeSM.hlsl
  97. 22 0
      tools/clang/test/CodeGenHLSL/mesh/amplification.hlsl
  98. 64 0
      tools/clang/test/CodeGenHLSL/mesh/illegalOutIndicesAssignment.hlsl
  99. 78 0
      tools/clang/test/CodeGenHLSL/mesh/mesh.hlsl
  100. 64 0
      tools/clang/test/CodeGenHLSL/mesh/multipleInPayload.hlsl

+ 307 - 237
docs/DXIL.rst

@@ -90,17 +90,19 @@ The shader model is specified as a named metadata in DXIL::
 
 The following values of <shaderModelName>_<major>_<minor> are supported:
 
-==================== ===================================== ===========
-Target               Legacy Models                         DXIL Models
-==================== ===================================== ===========
-Vertex shader (VS)   vs_4_0, vs_4_1, vs_5_0, vs_5_1        vs_6_0
-Hull shader (HS)     hs_5_0, hs_5_1                        hs_6_0
-Domain shader (DS)   ds_5_0, ds_5_1                        ds_6_0
-Geometry shader (GS) gs_4_0, gs_4_1, gs_5_0, gs_5_1        gs_6_0
-Pixel shader (PS)    ps_4_0, ps_4_1, ps_5_0, ps_5_1        ps_6_0
-Compute shader (CS)  cs_5_0 (cs_4_0 is mapped onto cs_5_0) cs_6_0
-Shader library       no support                            lib_6_1
-==================== ===================================== ===========
+========================= ===================================== ===========
+Target                    Legacy Models                         DXIL Models
+========================= ===================================== ===========
+Vertex shader (VS)        vs_4_0, vs_4_1, vs_5_0, vs_5_1        vs_6_0
+Hull shader (HS)          hs_5_0, hs_5_1                        hs_6_0
+Domain shader (DS)        ds_5_0, ds_5_1                        ds_6_0
+Geometry shader (GS)      gs_4_0, gs_4_1, gs_5_0, gs_5_1        gs_6_0
+Pixel shader (PS)         ps_4_0, ps_4_1, ps_5_0, ps_5_1        ps_6_0
+Compute shader (CS)       cs_5_0 (cs_4_0 is mapped onto cs_5_0) cs_6_0
+Shader library            no support                            lib_6_1
+Mesh shader (MS)          no support                            ms_6_5
+Amplification shader (AS) no support                            as_6_5
+========================= ===================================== ===========
 
 The DXIL verifier ensures that DXIL conforms to the specified shader model.
 
@@ -609,26 +611,30 @@ Signature Points are enumerated as follows in the SigPointKind
 .. <py::lines('SIGPOINT-RST')>hctdb_instrhelp.get_sigpoint_rst()</py>
 .. SIGPOINT-RST:BEGIN
 
-== ======== ======= ========== ============== ============= ============================================================================
-ID SigPoint Related ShaderKind PackingKind    SignatureKind Description
-== ======== ======= ========== ============== ============= ============================================================================
-0  VSIn     Invalid Vertex     InputAssembler Input         Ordinary Vertex Shader input from Input Assembler
-1  VSOut    Invalid Vertex     Vertex         Output        Ordinary Vertex Shader output that may feed Rasterizer
-2  PCIn     HSCPIn  Hull       None           Invalid       Patch Constant function non-patch inputs
-3  HSIn     HSCPIn  Hull       None           Invalid       Hull Shader function non-patch inputs
-4  HSCPIn   Invalid Hull       Vertex         Input         Hull Shader patch inputs - Control Points
-5  HSCPOut  Invalid Hull       Vertex         Output        Hull Shader function output - Control Point
-6  PCOut    Invalid Hull       PatchConstant  PatchConstant Patch Constant function output - Patch Constant data passed to Domain Shader
-7  DSIn     Invalid Domain     PatchConstant  PatchConstant Domain Shader regular input - Patch Constant data plus system values
-8  DSCPIn   Invalid Domain     Vertex         Input         Domain Shader patch input - Control Points
-9  DSOut    Invalid Domain     Vertex         Output        Domain Shader output - vertex data that may feed Rasterizer
-10 GSVIn    Invalid Geometry   Vertex         Input         Geometry Shader vertex input - qualified with primitive type
-11 GSIn     GSVIn   Geometry   None           Invalid       Geometry Shader non-vertex inputs (system values)
-12 GSOut    Invalid Geometry   Vertex         Output        Geometry Shader output - vertex data that may feed Rasterizer
-13 PSIn     Invalid Pixel      Vertex         Input         Pixel Shader input
-14 PSOut    Invalid Pixel      Target         Output        Pixel Shader output
-15 CSIn     Invalid Compute    None           Invalid       Compute Shader input
-== ======== ======= ========== ============== ============= ============================================================================
+== ======== ======= ============= ============== ================ ============================================================================
+ID SigPoint Related ShaderKind    PackingKind    SignatureKind    Description
+== ======== ======= ============= ============== ================ ============================================================================
+0  VSIn     Invalid Vertex        InputAssembler Input            Ordinary Vertex Shader input from Input Assembler
+1  VSOut    Invalid Vertex        Vertex         Output           Ordinary Vertex Shader output that may feed Rasterizer
+2  PCIn     HSCPIn  Hull          None           Invalid          Patch Constant function non-patch inputs
+3  HSIn     HSCPIn  Hull          None           Invalid          Hull Shader function non-patch inputs
+4  HSCPIn   Invalid Hull          Vertex         Input            Hull Shader patch inputs - Control Points
+5  HSCPOut  Invalid Hull          Vertex         Output           Hull Shader function output - Control Point
+6  PCOut    Invalid Hull          PatchConstant  PatchConstOrPrim Patch Constant function output - Patch Constant data passed to Domain Shader
+7  DSIn     Invalid Domain        PatchConstant  PatchConstOrPrim Domain Shader regular input - Patch Constant data plus system values
+8  DSCPIn   Invalid Domain        Vertex         Input            Domain Shader patch input - Control Points
+9  DSOut    Invalid Domain        Vertex         Output           Domain Shader output - vertex data that may feed Rasterizer
+10 GSVIn    Invalid Geometry      Vertex         Input            Geometry Shader vertex input - qualified with primitive type
+11 GSIn     GSVIn   Geometry      None           Invalid          Geometry Shader non-vertex inputs (system values)
+12 GSOut    Invalid Geometry      Vertex         Output           Geometry Shader output - vertex data that may feed Rasterizer
+13 PSIn     Invalid Pixel         Vertex         Input            Pixel Shader input
+14 PSOut    Invalid Pixel         Target         Output           Pixel Shader output
+15 CSIn     Invalid Compute       None           Invalid          Compute Shader input
+16 MSIn     Invalid Mesh          None           Invalid          Mesh Shader input
+17 MSOut    Invalid Mesh          Vertex         Output           Mesh Shader vertices output
+18 MSPOut   Invalid Mesh          Vertex         PatchConstOrPrim Mesh Shader primitives output
+19 ASIn     Invalid Amplification None           Invalid          Amplification Shader input
+== ======== ======= ============= ============== ================ ============================================================================
 
 .. SIGPOINT-RST:END
 
@@ -663,40 +669,41 @@ Semantic Interpretations for each SemanticKind at each SigPointKind are as follo
 .. <py::lines('SEMINT-TABLE-RST')>hctdb_instrhelp.get_sem_interpretation_table_rst()</py>
 .. SEMINT-TABLE-RST:BEGIN
 
-====================== ============ ======== ============ ============ ======== ======== ========== ============ ======== ======== ======== ============ ======== ============= ============= ========
-Semantic               VSIn         VSOut    PCIn         HSIn         HSCPIn   HSCPOut  PCOut      DSIn         DSCPIn   DSOut    GSVIn    GSIn         GSOut    PSIn          PSOut         CSIn
-====================== ============ ======== ============ ============ ======== ======== ========== ============ ======== ======== ======== ============ ======== ============= ============= ========
-Arbitrary              Arb          Arb      NA           NA           Arb      Arb      Arb        Arb          Arb      Arb      Arb      NA           Arb      Arb           NA            NA
-VertexID               SV           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NA
-InstanceID             SV           Arb      NA           NA           Arb      Arb      NA         NA           Arb      Arb      Arb      NA           Arb      Arb           NA            NA
-Position               Arb          SV       NA           NA           SV       SV       Arb        Arb          SV       SV       SV       NA           SV       SV            NA            NA
-RenderTargetArrayIndex Arb          SV       NA           NA           SV       SV       Arb        Arb          SV       SV       SV       NA           SV       SV            NA            NA
-ViewPortArrayIndex     Arb          SV       NA           NA           SV       SV       Arb        Arb          SV       SV       SV       NA           SV       SV            NA            NA
-ClipDistance           Arb          ClipCull NA           NA           ClipCull ClipCull Arb        Arb          ClipCull ClipCull ClipCull NA           ClipCull ClipCull      NA            NA
-CullDistance           Arb          ClipCull NA           NA           ClipCull ClipCull Arb        Arb          ClipCull ClipCull ClipCull NA           ClipCull ClipCull      NA            NA
-OutputControlPointID   NA           NA       NA           NotInSig     NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NA
-DomainLocation         NA           NA       NA           NA           NA       NA       NA         NotInSig     NA       NA       NA       NA           NA       NA            NA            NA
-PrimitiveID            NA           NA       NotInSig     NotInSig     NA       NA       NA         NotInSig     NA       NA       NA       Shadow       SGV      SGV           NA            NA
-GSInstanceID           NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NotInSig     NA       NA            NA            NA
-SampleIndex            NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       Shadow _41    NA            NA
-IsFrontFace            NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           SGV      SGV           NA            NA
-Coverage               NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NotInSig _50  NotPacked _41 NA
-InnerCoverage          NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NotInSig _50  NA            NA
-Target                 NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            Target        NA
-Depth                  NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked     NA
-DepthLessEqual         NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked _50 NA
-DepthGreaterEqual      NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked _50 NA
-StencilRef             NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked _50 NA
-DispatchThreadID       NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig
-GroupID                NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig
-GroupIndex             NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig
-GroupThreadID          NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig
-TessFactor             NA           NA       NA           NA           NA       NA       TessFactor TessFactor   NA       NA       NA       NA           NA       NA            NA            NA
-InsideTessFactor       NA           NA       NA           NA           NA       NA       TessFactor TessFactor   NA       NA       NA       NA           NA       NA            NA            NA
-ViewID                 NotInSig _61 NA       NotInSig _61 NotInSig _61 NA       NA       NA         NotInSig _61 NA       NA       NA       NotInSig _61 NA       NotInSig _61  NA            NA
-Barycentrics           NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NotPacked _61 NA            NA
-ShadingRate            NA           SV _64   NA           NA           SV _64   SV _64   NA         NA           SV _64   SV _64   SV _64   NA           SV _64   SV _64        NA            NA
-====================== ============ ======== ============ ============ ======== ======== ========== ============ ======== ======== ======== ============ ======== ============= ============= ========
+====================== ============ ======== ============ ============ ======== ======== ========== ============ ======== ======== ======== ============ ======== ============= ============= ======== ============ ============ ======= ============
+Semantic               VSIn         VSOut    PCIn         HSIn         HSCPIn   HSCPOut  PCOut      DSIn         DSCPIn   DSOut    GSVIn    GSIn         GSOut    PSIn          PSOut         CSIn     MSIn         MSOut        MSPOut  ASIn
+====================== ============ ======== ============ ============ ======== ======== ========== ============ ======== ======== ======== ============ ======== ============= ============= ======== ============ ============ ======= ============
+Arbitrary              Arb          Arb      NA           NA           Arb      Arb      Arb        Arb          Arb      Arb      Arb      NA           Arb      Arb           NA            NA       NA           Arb _65      Arb _65 NA
+VertexID               SV           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NA       NA           NA           NA      NA
+InstanceID             SV           Arb      NA           NA           Arb      Arb      NA         NA           Arb      Arb      Arb      NA           Arb      Arb           NA            NA       NA           NA           NA      NA
+Position               Arb          SV       NA           NA           SV       SV       Arb        Arb          SV       SV       SV       NA           SV       SV            NA            NA       NA           SV _65       NA      NA
+RenderTargetArrayIndex Arb          SV       NA           NA           SV       SV       Arb        Arb          SV       SV       SV       NA           SV       SV            NA            NA       NA           NA           SV _65  NA
+ViewPortArrayIndex     Arb          SV       NA           NA           SV       SV       Arb        Arb          SV       SV       SV       NA           SV       SV            NA            NA       NA           NA           SV _65  NA
+ClipDistance           Arb          ClipCull NA           NA           ClipCull ClipCull Arb        Arb          ClipCull ClipCull ClipCull NA           ClipCull ClipCull      NA            NA       NA           ClipCull _65 NA      NA
+CullDistance           Arb          ClipCull NA           NA           ClipCull ClipCull Arb        Arb          ClipCull ClipCull ClipCull NA           ClipCull ClipCull      NA            NA       NA           ClipCull _65 NA      NA
+OutputControlPointID   NA           NA       NA           NotInSig     NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NA       NA           NA           NA      NA
+DomainLocation         NA           NA       NA           NA           NA       NA       NA         NotInSig     NA       NA       NA       NA           NA       NA            NA            NA       NA           NA           NA      NA
+PrimitiveID            NA           NA       NotInSig     NotInSig     NA       NA       NA         NotInSig     NA       NA       NA       Shadow       SGV      SGV           NA            NA       NA           NA           SV _65  NA
+GSInstanceID           NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NotInSig     NA       NA            NA            NA       NA           NA           NA      NA
+SampleIndex            NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       Shadow _41    NA            NA       NA           NA           NA      NA
+IsFrontFace            NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           SGV      SGV           NA            NA       NA           NA           NA      NA
+Coverage               NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NotInSig _50  NotPacked _41 NA       NA           NA           NA      NA
+InnerCoverage          NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NotInSig _50  NA            NA       NA           NA           NA      NA
+Target                 NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            Target        NA       NA           NA           NA      NA
+Depth                  NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked     NA       NA           NA           NA      NA
+DepthLessEqual         NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked _50 NA       NA           NA           NA      NA
+DepthGreaterEqual      NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked _50 NA       NA           NA           NA      NA
+StencilRef             NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked _50 NA       NA           NA           NA      NA
+DispatchThreadID       NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig NotInSig _65 NA           NA      NotInSig _65
+GroupID                NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig NotInSig _65 NA           NA      NotInSig _65
+GroupIndex             NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig NotInSig _65 NA           NA      NotInSig _65
+GroupThreadID          NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig NotInSig _65 NA           NA      NotInSig _65
+TessFactor             NA           NA       NA           NA           NA       NA       TessFactor TessFactor   NA       NA       NA       NA           NA       NA            NA            NA       NA           NA           NA      NA
+InsideTessFactor       NA           NA       NA           NA           NA       NA       TessFactor TessFactor   NA       NA       NA       NA           NA       NA            NA            NA       NA           NA           NA      NA
+ViewID                 NotInSig _61 NA       NotInSig _61 NotInSig _61 NA       NA       NA         NotInSig _61 NA       NA       NA       NotInSig _61 NA       NotInSig _61  NA            NA       NotInSig _65 NA           NA      NA
+Barycentrics           NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NotPacked _61 NA            NA       NA           NA           NA      NA
+ShadingRate            NA           SV _64   NA           NA           SV _64   SV _64   NA         NA           SV _64   SV _64   SV _64   NA           SV _64   SV _64        NA            NA       NA           NA           NA      NA
+CullPrimitive          NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NA       NA           NA           SV _65  NA
+====================== ============ ======== ============ ============ ======== ======== ========== ============ ======== ======== ======== ============ ======== ============= ============= ======== ============ ============ ======= ============
 
 .. SEMINT-TABLE-RST:END
 
@@ -2085,178 +2092,224 @@ Opcodes are defined on a dense range and will be provided as enum in a header fi
 .. <py::lines('OPCODES-RST')>hctdb_instrhelp.get_opcodes_rst()</py>
 .. OPCODES-RST:BEGIN
 
-=== ============================= =======================================================================================================================================================================================================================
-ID  Name                          Description
-=== ============================= =======================================================================================================================================================================================================================
-0   TempRegLoad_                  Helper load operation
-1   TempRegStore_                 Helper store operation
-2   MinPrecXRegLoad_              Helper load operation for minprecision
-3   MinPrecXRegStore_             Helper store operation for minprecision
-4   LoadInput_                    Loads the value from shader input
-5   StoreOutput_                  Stores the value to shader output
-6   FAbs_                         returns the absolute value of the input value.
-7   Saturate_                     clamps the result of a single or double precision floating point value to [0.0f...1.0f]
-8   IsNaN_                        Returns true if x is NAN or QNAN, false otherwise.
-9   IsInf_                        Returns true if x is +INF or -INF, false otherwise.
-10  IsFinite_                     Returns true if x is finite, false otherwise.
-11  IsNormal_                     returns IsNormal
-12  Cos_                          returns cosine(theta) for theta in radians.
-13  Sin_                          returns sine(theta) for theta in radians.
-14  Tan_                          returns tan(theta) for theta in radians.
-15  Acos_                         Returns the arccosine of the specified value. Input should be a floating-point value within the range of -1 to 1.
-16  Asin_                         Returns the arccosine of the specified value. Input should be a floating-point value within the range of -1 to 1
-17  Atan_                         Returns the arctangent of the specified value. The return value is within the range of -PI/2 to PI/2.
-18  Hcos_                         returns the hyperbolic cosine of the specified value.
-19  Hsin_                         returns the hyperbolic sine of the specified value.
-20  Htan_                         returns the hyperbolic tangent of the specified value.
-21  Exp_                          returns 2^exponent
-22  Frc_                          extract fracitonal component.
-23  Log_                          returns log base 2.
-24  Sqrt_                         returns square root
-25  Rsqrt_                        returns reciprocal square root (1 / sqrt(src)
-26  Round_ne_                     floating-point round to integral float.
-27  Round_ni_                     floating-point round to integral float.
-28  Round_pi_                     floating-point round to integral float.
-29  Round_z_                      floating-point round to integral float.
-30  Bfrev_                        Reverses the order of the bits.
-31  Countbits_                    Counts the number of bits in the input integer.
-32  FirstbitLo_                   Returns the location of the first set bit starting from the lowest order bit and working upward.
-33  FirstbitHi_                   Returns the location of the first set bit starting from the highest order bit and working downward.
-34  FirstbitSHi_                  Returns the location of the first set bit from the highest order bit based on the sign.
-35  FMax_                         returns a if a >= b, else b
-36  FMin_                         returns a if a < b, else b
-37  IMax_                         IMax(a,b) returns a if a > b, else b
-38  IMin_                         IMin(a,b) returns a if a < b, else b
-39  UMax_                         unsigned integer maximum. UMax(a,b) = a > b ? a : b
-40  UMin_                         unsigned integer minimum. UMin(a,b) = a < b ? a : b
-41  IMul_                         multiply of 32-bit operands to produce the correct full 64-bit result.
-42  UMul_                         multiply of 32-bit operands to produce the correct full 64-bit result.
-43  UDiv_                         unsigned divide of the 32-bit operand src0 by the 32-bit operand src1.
-44  UAddc_                        unsigned add of 32-bit operand with the carry
-45  USubb_                        unsigned subtract of 32-bit operands with the borrow
-46  FMad_                         floating point multiply & add
-47  Fma_                          fused multiply-add
-48  IMad_                         Signed integer multiply & add
-49  UMad_                         Unsigned integer multiply & add
-50  Msad_                         masked Sum of Absolute Differences.
-51  Ibfe_                         Integer bitfield extract
-52  Ubfe_                         Unsigned integer bitfield extract
-53  Bfi_                          Given a bit range from the LSB of a number, places that number of bits in another number at any offset
-54  Dot2_                         Two-dimensional vector dot-product
-55  Dot3_                         Three-dimensional vector dot-product
-56  Dot4_                         Four-dimensional vector dot-product
-57  CreateHandle                  creates the handle to a resource
-58  CBufferLoad                   loads a value from a constant buffer resource
-59  CBufferLoadLegacy             loads a value from a constant buffer resource
-60  Sample                        samples a texture
-61  SampleBias                    samples a texture after applying the input bias to the mipmap level
-62  SampleLevel                   samples a texture using a mipmap-level offset
-63  SampleGrad                    samples a texture using a gradient to influence the way the sample location is calculated
-64  SampleCmp                     samples a texture and compares a single component against the specified comparison value
-65  SampleCmpLevelZero            samples a texture and compares a single component against the specified comparison value
-66  TextureLoad                   reads texel data without any filtering or sampling
-67  TextureStore                  reads texel data without any filtering or sampling
-68  BufferLoad                    reads from a TypedBuffer
-69  BufferStore                   writes to a RWTypedBuffer
-70  BufferUpdateCounter           atomically increments/decrements the hidden 32-bit counter stored with a Count or Append UAV
-71  CheckAccessFullyMapped        determines whether all values from a Sample, Gather, or Load operation accessed mapped tiles in a tiled resource
-72  GetDimensions                 gets texture size information
-73  TextureGather                 gathers the four texels that would be used in a bi-linear filtering operation
-74  TextureGatherCmp              same as TextureGather, except this instrution performs comparison on texels, similar to SampleCmp
-75  Texture2DMSGetSamplePosition  gets the position of the specified sample
-76  RenderTargetGetSamplePosition gets the position of the specified sample
-77  RenderTargetGetSampleCount    gets the number of samples for a render target
-78  AtomicBinOp                   performs an atomic operation on two operands
-79  AtomicCompareExchange         atomic compare and exchange to memory
-80  Barrier                       inserts a memory barrier in the shader
-81  CalculateLOD                  calculates the level of detail
-82  Discard                       discard the current pixel
-83  DerivCoarseX_                 computes the rate of change per stamp in x direction.
-84  DerivCoarseY_                 computes the rate of change per stamp in y direction.
-85  DerivFineX_                   computes the rate of change per pixel in x direction.
-86  DerivFineY_                   computes the rate of change per pixel in y direction.
-87  EvalSnapped                   evaluates an input attribute at pixel center with an offset
-88  EvalSampleIndex               evaluates an input attribute at a sample location
-89  EvalCentroid                  evaluates an input attribute at pixel center
-90  SampleIndex                   returns the sample index in a sample-frequency pixel shader
-91  Coverage                      returns the coverage mask input in a pixel shader
-92  InnerCoverage                 returns underestimated coverage input from conservative rasterization in a pixel shader
-93  ThreadId                      reads the thread ID
-94  GroupId                       reads the group ID (SV_GroupID)
-95  ThreadIdInGroup               reads the thread ID within the group (SV_GroupThreadID)
-96  FlattenedThreadIdInGroup      provides a flattened index for a given thread within a given group (SV_GroupIndex)
-97  EmitStream                    emits a vertex to a given stream
-98  CutStream                     completes the current primitive topology at the specified stream
-99  EmitThenCutStream             equivalent to an EmitStream followed by a CutStream
-100 GSInstanceID                  GSInstanceID
-101 MakeDouble                    creates a double value
-102 SplitDouble                   splits a double into low and high parts
-103 LoadOutputControlPoint        LoadOutputControlPoint
-104 LoadPatchConstant             LoadPatchConstant
-105 DomainLocation                DomainLocation
-106 StorePatchConstant            StorePatchConstant
-107 OutputControlPointID          OutputControlPointID
-108 PrimitiveID                   PrimitiveID
-109 CycleCounterLegacy            CycleCounterLegacy
-110 WaveIsFirstLane               returns 1 for the first lane in the wave
-111 WaveGetLaneIndex              returns the index of the current lane in the wave
-112 WaveGetLaneCount              returns the number of lanes in the wave
-113 WaveAnyTrue                   returns 1 if any of the lane evaluates the value to true
-114 WaveAllTrue                   returns 1 if all the lanes evaluate the value to true
-115 WaveActiveAllEqual            returns 1 if all the lanes have the same value
-116 WaveActiveBallot              returns a struct with a bit set for each lane where the condition is true
-117 WaveReadLaneAt                returns the value from the specified lane
-118 WaveReadLaneFirst             returns the value from the first lane
-119 WaveActiveOp                  returns the result the operation across waves
-120 WaveActiveBit                 returns the result of the operation across all lanes
-121 WavePrefixOp                  returns the result of the operation on prior lanes
-122 QuadReadLaneAt                reads from a lane in the quad
-123 QuadOp                        returns the result of a quad-level operation
-124 BitcastI16toF16               bitcast between different sizes
-125 BitcastF16toI16               bitcast between different sizes
-126 BitcastI32toF32               bitcast between different sizes
-127 BitcastF32toI32               bitcast between different sizes
-128 BitcastI64toF64               bitcast between different sizes
-129 BitcastF64toI64               bitcast between different sizes
-130 LegacyF32ToF16                legacy fuction to convert float (f32) to half (f16) (this is not related to min-precision)
-131 LegacyF16ToF32                legacy fuction to convert half (f16) to float (f32) (this is not related to min-precision)
-132 LegacyDoubleToFloat           legacy fuction to convert double to float
-133 LegacyDoubleToSInt32          legacy fuction to convert double to int32
-134 LegacyDoubleToUInt32          legacy fuction to convert double to uint32
-135 WaveAllBitCount               returns the count of bits set to 1 across the wave
-136 WavePrefixBitCount            returns the count of bits set to 1 on prior lanes
-137 AttributeAtVertex_            returns the values of the attributes at the vertex.
-138 ViewID                        returns the view index
-139 RawBufferLoad                 reads from a raw buffer and structured buffer
-140 RawBufferStore                writes to a RWByteAddressBuffer or RWStructuredBuffer
-141 InstanceID                    The user-provided InstanceID on the bottom-level acceleration structure instance within the top-level structure
-142 InstanceIndex                 The autogenerated index of the current instance in the top-level structure
-143 HitKind                       Returns the value passed as HitKind in ReportIntersection().  If intersection was reported by fixed-function triangle intersection, HitKind will be one of HIT_KIND_TRIANGLE_FRONT_FACE or HIT_KIND_TRIANGLE_BACK_FACE.
-144 RayFlags                      uint containing the current ray flags.
-145 DispatchRaysIndex             The current x and y location within the Width and Height
-146 DispatchRaysDimensions        The Width and Height values from the D3D12_DISPATCH_RAYS_DESC structure provided to the originating DispatchRays() call.
-147 WorldRayOrigin                The world-space origin for the current ray.
-148 WorldRayDirection             The world-space direction for the current ray.
-149 ObjectRayOrigin               Object-space origin for the current ray.
-150 ObjectRayDirection            Object-space direction for the current ray.
-151 ObjectToWorld                 Matrix for transforming from object-space to world-space.
-152 WorldToObject                 Matrix for transforming from world-space to object-space.
-153 RayTMin                       float representing the parametric starting point for the ray.
-154 RayTCurrent                   float representing the current parametric ending point for the ray
-155 IgnoreHit                     Used in an any hit shader to reject an intersection and terminate the shader
-156 AcceptHitAndEndSearch         Used in an any hit shader to abort the ray query and the intersection shader (if any). The current hit is committed and execution passes to the closest hit shader with the closest hit recorded so far
-157 TraceRay                      returns the view index
-158 ReportHit                     returns true if hit was accepted
-159 CallShader                    Call a shader in the callable shader table supplied through the DispatchRays() API
-160 CreateHandleForLib            create resource handle from resource struct for library
-161 PrimitiveIndex                PrimitiveIndex for raytracing shaders
-162 Dot2AddHalf                   2D half dot product with accumulate to float
-163 Dot4AddI8Packed               signed dot product of 4 x i8 vectors packed into i32, with accumulate to i32
-164 Dot4AddU8Packed               unsigned dot product of 4 x u8 vectors packed into i32, with accumulate to i32
-165 WaveMatch                     returns the bitmask of active lanes that have the same value
-166 WaveMultiPrefixOp             returns the result of the operation on groups of lanes identified by a bitmask
-167 WaveMultiPrefixBitCount       returns the count of bits set to 1 on groups of lanes identified by a bitmask
-=== ============================= =======================================================================================================================================================================================================================
+=== ============================================== =======================================================================================================================================================================================================================
+ID  Name                                           Description
+=== ============================================== =======================================================================================================================================================================================================================
+0   TempRegLoad_                                   Helper load operation
+1   TempRegStore_                                  Helper store operation
+2   MinPrecXRegLoad_                               Helper load operation for minprecision
+3   MinPrecXRegStore_                              Helper store operation for minprecision
+4   LoadInput_                                     Loads the value from shader input
+5   StoreOutput_                                   Stores the value to shader output
+6   FAbs_                                          returns the absolute value of the input value.
+7   Saturate_                                      clamps the result of a single or double precision floating point value to [0.0f...1.0f]
+8   IsNaN_                                         Returns true if x is NAN or QNAN, false otherwise.
+9   IsInf_                                         Returns true if x is +INF or -INF, false otherwise.
+10  IsFinite_                                      Returns true if x is finite, false otherwise.
+11  IsNormal_                                      returns IsNormal
+12  Cos_                                           returns cosine(theta) for theta in radians.
+13  Sin_                                           returns sine(theta) for theta in radians.
+14  Tan_                                           returns tan(theta) for theta in radians.
+15  Acos_                                          Returns the arccosine of the specified value. Input should be a floating-point value within the range of -1 to 1.
+16  Asin_                                          Returns the arcsine of the specified value. Input should be a floating-point value within the range of -1 to 1
+17  Atan_                                          Returns the arctangent of the specified value. The return value is within the range of -PI/2 to PI/2.
+18  Hcos_                                          returns the hyperbolic cosine of the specified value.
+19  Hsin_                                          returns the hyperbolic sine of the specified value.
+20  Htan_                                          returns the hyperbolic tangent of the specified value.
+21  Exp_                                           returns 2^exponent
+22  Frc_                                           extract fractional component.
+23  Log_                                           returns log base 2.
+24  Sqrt_                                          returns square root
+25  Rsqrt_                                         returns reciprocal square root (1 / sqrt(src))
+26  Round_ne_                                      floating-point round to integral float.
+27  Round_ni_                                      floating-point round to integral float.
+28  Round_pi_                                      floating-point round to integral float.
+29  Round_z_                                       floating-point round to integral float.
+30  Bfrev_                                         Reverses the order of the bits.
+31  Countbits_                                     Counts the number of bits in the input integer.
+32  FirstbitLo_                                    Returns the location of the first set bit starting from the lowest order bit and working upward.
+33  FirstbitHi_                                    Returns the location of the first set bit starting from the highest order bit and working downward.
+34  FirstbitSHi_                                   Returns the location of the first set bit from the highest order bit based on the sign.
+35  FMax_                                          returns a if a >= b, else b
+36  FMin_                                          returns a if a < b, else b
+37  IMax_                                          IMax(a,b) returns a if a > b, else b
+38  IMin_                                          IMin(a,b) returns a if a < b, else b
+39  UMax_                                          unsigned integer maximum. UMax(a,b) = a > b ? a : b
+40  UMin_                                          unsigned integer minimum. UMin(a,b) = a < b ? a : b
+41  IMul_                                          multiply of 32-bit operands to produce the correct full 64-bit result.
+42  UMul_                                          multiply of 32-bit operands to produce the correct full 64-bit result.
+43  UDiv_                                          unsigned divide of the 32-bit operand src0 by the 32-bit operand src1.
+44  UAddc_                                         unsigned add of 32-bit operand with the carry
+45  USubb_                                         unsigned subtract of 32-bit operands with the borrow
+46  FMad_                                          floating point multiply & add
+47  Fma_                                           fused multiply-add
+48  IMad_                                          Signed integer multiply & add
+49  UMad_                                          Unsigned integer multiply & add
+50  Msad_                                          masked Sum of Absolute Differences.
+51  Ibfe_                                          Integer bitfield extract
+52  Ubfe_                                          Unsigned integer bitfield extract
+53  Bfi_                                           Given a bit range from the LSB of a number, places that number of bits in another number at any offset
+54  Dot2_                                          Two-dimensional vector dot-product
+55  Dot3_                                          Three-dimensional vector dot-product
+56  Dot4_                                          Four-dimensional vector dot-product
+57  CreateHandle                                   creates the handle to a resource
+58  CBufferLoad                                    loads a value from a constant buffer resource
+59  CBufferLoadLegacy                              loads a value from a constant buffer resource
+60  Sample                                         samples a texture
+61  SampleBias                                     samples a texture after applying the input bias to the mipmap level
+62  SampleLevel                                    samples a texture using a mipmap-level offset
+63  SampleGrad                                     samples a texture using a gradient to influence the way the sample location is calculated
+64  SampleCmp                                      samples a texture and compares a single component against the specified comparison value
+65  SampleCmpLevelZero                             samples a texture and compares a single component against the specified comparison value
+66  TextureLoad                                    reads texel data without any filtering or sampling
+67  TextureStore                                   writes texel data to a texture
+68  BufferLoad                                     reads from a TypedBuffer
+69  BufferStore                                    writes to a RWTypedBuffer
+70  BufferUpdateCounter                            atomically increments/decrements the hidden 32-bit counter stored with a Count or Append UAV
+71  CheckAccessFullyMapped                         determines whether all values from a Sample, Gather, or Load operation accessed mapped tiles in a tiled resource
+72  GetDimensions                                  gets texture size information
+73  TextureGather                                  gathers the four texels that would be used in a bi-linear filtering operation
+74  TextureGatherCmp                               same as TextureGather, except this instruction performs comparison on texels, similar to SampleCmp
+75  Texture2DMSGetSamplePosition                   gets the position of the specified sample
+76  RenderTargetGetSamplePosition                  gets the position of the specified sample
+77  RenderTargetGetSampleCount                     gets the number of samples for a render target
+78  AtomicBinOp                                    performs an atomic operation on two operands
+79  AtomicCompareExchange                          atomic compare and exchange to memory
+80  Barrier                                        inserts a memory barrier in the shader
+81  CalculateLOD                                   calculates the level of detail
+82  Discard                                        discard the current pixel
+83  DerivCoarseX_                                  computes the rate of change per stamp in x direction.
+84  DerivCoarseY_                                  computes the rate of change per stamp in y direction.
+85  DerivFineX_                                    computes the rate of change per pixel in x direction.
+86  DerivFineY_                                    computes the rate of change per pixel in y direction.
+87  EvalSnapped                                    evaluates an input attribute at pixel center with an offset
+88  EvalSampleIndex                                evaluates an input attribute at a sample location
+89  EvalCentroid                                   evaluates an input attribute at pixel center
+90  SampleIndex                                    returns the sample index in a sample-frequency pixel shader
+91  Coverage                                       returns the coverage mask input in a pixel shader
+92  InnerCoverage                                  returns underestimated coverage input from conservative rasterization in a pixel shader
+93  ThreadId                                       reads the thread ID
+94  GroupId                                        reads the group ID (SV_GroupID)
+95  ThreadIdInGroup                                reads the thread ID within the group (SV_GroupThreadID)
+96  FlattenedThreadIdInGroup                       provides a flattened index for a given thread within a given group (SV_GroupIndex)
+97  EmitStream                                     emits a vertex to a given stream
+98  CutStream                                      completes the current primitive topology at the specified stream
+99  EmitThenCutStream                              equivalent to an EmitStream followed by a CutStream
+100 GSInstanceID                                   GSInstanceID
+101 MakeDouble                                     creates a double value
+102 SplitDouble                                    splits a double into low and high parts
+103 LoadOutputControlPoint                         LoadOutputControlPoint
+104 LoadPatchConstant                              LoadPatchConstant
+105 DomainLocation                                 DomainLocation
+106 StorePatchConstant                             StorePatchConstant
+107 OutputControlPointID                           OutputControlPointID
+108 PrimitiveID                                    PrimitiveID
+109 CycleCounterLegacy                             CycleCounterLegacy
+110 WaveIsFirstLane                                returns 1 for the first lane in the wave
+111 WaveGetLaneIndex                               returns the index of the current lane in the wave
+112 WaveGetLaneCount                               returns the number of lanes in the wave
+113 WaveAnyTrue                                    returns 1 if any of the lanes evaluates the value to true
+114 WaveAllTrue                                    returns 1 if all the lanes evaluate the value to true
+115 WaveActiveAllEqual                             returns 1 if all the lanes have the same value
+116 WaveActiveBallot                               returns a struct with a bit set for each lane where the condition is true
+117 WaveReadLaneAt                                 returns the value from the specified lane
+118 WaveReadLaneFirst                              returns the value from the first lane
+119 WaveActiveOp                                   returns the result of the operation across waves
+120 WaveActiveBit                                  returns the result of the operation across all lanes
+121 WavePrefixOp                                   returns the result of the operation on prior lanes
+122 QuadReadLaneAt                                 reads from a lane in the quad
+123 QuadOp                                         returns the result of a quad-level operation
+124 BitcastI16toF16                                bitcast between different sizes
+125 BitcastF16toI16                                bitcast between different sizes
+126 BitcastI32toF32                                bitcast between different sizes
+127 BitcastF32toI32                                bitcast between different sizes
+128 BitcastI64toF64                                bitcast between different sizes
+129 BitcastF64toI64                                bitcast between different sizes
+130 LegacyF32ToF16                                 legacy function to convert float (f32) to half (f16) (this is not related to min-precision)
+131 LegacyF16ToF32                                 legacy function to convert half (f16) to float (f32) (this is not related to min-precision)
+132 LegacyDoubleToFloat                            legacy function to convert double to float
+133 LegacyDoubleToSInt32                           legacy function to convert double to int32
+134 LegacyDoubleToUInt32                           legacy function to convert double to uint32
+135 WaveAllBitCount                                returns the count of bits set to 1 across the wave
+136 WavePrefixBitCount                             returns the count of bits set to 1 on prior lanes
+137 AttributeAtVertex_                             returns the values of the attributes at the vertex.
+138 ViewID                                         returns the view index
+139 RawBufferLoad                                  reads from a raw buffer and structured buffer
+140 RawBufferStore                                 writes to a RWByteAddressBuffer or RWStructuredBuffer
+141 InstanceID                                     The user-provided InstanceID on the bottom-level acceleration structure instance within the top-level structure
+142 InstanceIndex                                  The autogenerated index of the current instance in the top-level structure
+143 HitKind                                        Returns the value passed as HitKind in ReportIntersection().  If intersection was reported by fixed-function triangle intersection, HitKind will be one of HIT_KIND_TRIANGLE_FRONT_FACE or HIT_KIND_TRIANGLE_BACK_FACE.
+144 RayFlags                                       uint containing the current ray flags.
+145 DispatchRaysIndex                              The current x and y location within the Width and Height
+146 DispatchRaysDimensions                         The Width and Height values from the D3D12_DISPATCH_RAYS_DESC structure provided to the originating DispatchRays() call.
+147 WorldRayOrigin                                 The world-space origin for the current ray.
+148 WorldRayDirection                              The world-space direction for the current ray.
+149 ObjectRayOrigin                                Object-space origin for the current ray.
+150 ObjectRayDirection                             Object-space direction for the current ray.
+151 ObjectToWorld                                  Matrix for transforming from object-space to world-space.
+152 WorldToObject                                  Matrix for transforming from world-space to object-space.
+153 RayTMin                                        float representing the parametric starting point for the ray.
+154 RayTCurrent                                    float representing the current parametric ending point for the ray
+155 IgnoreHit                                      Used in an any hit shader to reject an intersection and terminate the shader
+156 AcceptHitAndEndSearch                          Used in an any hit shader to abort the ray query and the intersection shader (if any). The current hit is committed and execution passes to the closest hit shader with the closest hit recorded so far
+157 TraceRay                                       initiates raytrace
+158 ReportHit                                      returns true if hit was accepted
+159 CallShader                                     Call a shader in the callable shader table supplied through the DispatchRays() API
+160 CreateHandleForLib                             create resource handle from resource struct for library
+161 PrimitiveIndex                                 PrimitiveIndex for raytracing shaders
+162 Dot2AddHalf                                    2D half dot product with accumulate to float
+163 Dot4AddI8Packed                                signed dot product of 4 x i8 vectors packed into i32, with accumulate to i32
+164 Dot4AddU8Packed                                unsigned dot product of 4 x u8 vectors packed into i32, with accumulate to i32
+165 WaveMatch                                      returns the bitmask of active lanes that have the same value
+166 WaveMultiPrefixOp                              returns the result of the operation on groups of lanes identified by a bitmask
+167 WaveMultiPrefixBitCount                        returns the count of bits set to 1 on groups of lanes identified by a bitmask
+168 SetMeshOutputCounts                            Mesh shader intrinsic SetMeshOutputCounts
+169 EmitIndices                                    emit a primitive's vertex indices in a mesh shader
+170 GetMeshPayload                                 gets the mesh payload passed from the amplification shader
+171 StoreVertexOutput                              stores the value to mesh shader vertex output
+172 StorePrimitiveOutput                           stores the value to mesh shader primitive output
+173 DispatchMesh                                   Amplification shader intrinsic DispatchMesh
+174 WriteSamplerFeedback                           updates a feedback texture for a sampling operation
+175 WriteSamplerFeedbackBias                       updates a feedback texture for a sampling operation with a bias on the mipmap level
+176 WriteSamplerFeedbackLevel                      updates a feedback texture for a sampling operation with a mipmap-level offset
+177 WriteSamplerFeedbackGrad                       updates a feedback texture for a sampling operation with explicit gradients
+178 AllocateRayQuery                               allocates space for RayQuery and returns a handle
+179 RayQuery_TraceRayInline                        initializes RayQuery for raytrace
+180 RayQuery_Proceed                               advances a ray query
+181 RayQuery_Abort                                 aborts a ray query
+182 RayQuery_CommitNonOpaqueTriangleHit            commits a non opaque triangle hit
+183 RayQuery_CommitProceduralPrimitiveHit          commits a procedural primitive hit
+184 RayQuery_CommittedStatus                       returns uint status (COMMITTED_STATUS) of the committed hit in a ray query
+185 RayQuery_CandidateType                         returns uint candidate type (CANDIDATE_TYPE) of the current hit candidate in a ray query, after Proceed() has returned true
+186 RayQuery_CandidateObjectToWorld3x4             returns matrix for transforming from object-space to world-space for a candidate hit.
+187 RayQuery_CandidateWorldToObject3x4             returns matrix for transforming from world-space to object-space for a candidate hit.
+188 RayQuery_CommittedObjectToWorld3x4             returns matrix for transforming from object-space to world-space for a Committed hit.
+189 RayQuery_CommittedWorldToObject3x4             returns matrix for transforming from world-space to object-space for a Committed hit.
+190 RayQuery_CandidateProceduralPrimitiveNonOpaque returns if current candidate procedural primitive is non opaque
+191 RayQuery_CandidateTriangleFrontFace            returns if current candidate triangle is front facing
+192 RayQuery_CommittedTriangleFrontFace            returns if current committed triangle is front facing
+193 RayQuery_CandidateTriangleBarycentrics         returns candidate triangle hit barycentrics
+194 RayQuery_CommittedTriangleBarycentrics         returns committed triangle hit barycentrics
+195 RayQuery_RayFlags                              returns ray flags
+196 RayQuery_WorldRayOrigin                        returns world ray origin
+197 RayQuery_WorldRayDirection                     returns world ray direction
+198 RayQuery_RayTMin                               returns float representing the parametric starting point for the ray.
+199 RayQuery_CandidateTriangleRayT                 returns float representing the parametric point on the ray for the current candidate triangle hit.
+200 RayQuery_CommittedRayT                         returns float representing the parametric point on the ray for the current committed hit.
+201 RayQuery_CandidateInstanceIndex                returns candidate hit instance index
+202 RayQuery_CandidateInstanceID                   returns candidate hit instance ID
+203 RayQuery_CandidateGeometryIndex                returns candidate hit geometry index
+204 RayQuery_CandidatePrimitiveIndex               returns candidate hit primitive index
+205 RayQuery_CandidateObjectRayOrigin              returns candidate hit object ray origin
+206 RayQuery_CandidateObjectRayDirection           returns candidate object ray direction
+207 RayQuery_CommittedInstanceIndex                returns committed hit instance index
+208 RayQuery_CommittedInstanceID                   returns committed hit instance ID
+209 RayQuery_CommittedGeometryIndex                returns committed hit geometry index
+210 RayQuery_CommittedPrimitiveIndex               returns committed hit primitive index
+211 RayQuery_CommittedObjectRayOrigin              returns committed hit object ray origin
+212 RayQuery_CommittedObjectRayDirection           returns committed object ray direction
+213 GeometryIndex                                  The autogenerated index of the current geometry in the bottom-level structure
+=== ============================================== =======================================================================================================================================================================================================================
 
 
 Acos
@@ -2957,13 +3010,19 @@ INSTR.MINPRECISIONNOTPRECISE             Instructions marked precise may not ref
 INSTR.MINPRECISONBITCAST                 Bitcast on minprecision types is not allowed
 INSTR.MIPLEVELFORGETDIMENSION            Use mip level on buffer when GetDimensions
 INSTR.MIPONUAVLOAD                       UAV load doesn't support mipLevel/sampleIndex
+INSTR.MISSINGSETMESHOUTPUTCOUNTS         Missing SetMeshOutputCounts call.
+INSTR.MULTIPLEGETMESHPAYLOAD             GetMeshPayload cannot be called multiple times.
+INSTR.MULTIPLESETMESHOUTPUTCOUNTS        SetMeshOutputCounts cannot be called multiple times.
 INSTR.NOGENERICPTRADDRSPACECAST          Address space cast between pointer types must have one part to be generic address space
 INSTR.NOIDIVBYZERO                       No signed integer division by zero
 INSTR.NOINDEFINITEACOS                   No indefinite arccosine
 INSTR.NOINDEFINITEASIN                   No indefinite arcsine
 INSTR.NOINDEFINITEDSXY                   No indefinite derivative calculation
 INSTR.NOINDEFINITELOG                    No indefinite logarithm
+INSTR.NONDOMINATINGDISPATCHMESH          Non-Dominating DispatchMesh call.
+INSTR.NONDOMINATINGSETMESHOUTPUTCOUNTS   Non-Dominating SetMeshOutputCounts call.
 INSTR.NOREADINGUNINITIALIZED             Instructions should not read uninitialized value
+INSTR.NOTONCEDISPATCHMESH                DispatchMesh must be called exactly once in an Amplification shader.
 INSTR.NOUDIVBYZERO                       No unsigned integer division by zero
 INSTR.OFFSETONUAVLOAD                    UAV load doesn't support offset
 INSTR.OLOAD                              DXIL intrinsic overload must be valid
@@ -3052,12 +3111,14 @@ META.VALIDSAMPLERMODE                    Invalid sampler mode on sampler
 META.VALUERANGE                          Metadata value must be within range
 META.WELLFORMED                          TODO - Metadata must be well-formed in operand count and types
 SM.64BITRAWBUFFERLOADSTORE               i64/f64 rawBufferLoad/Store overloads are allowed after SM 6.3
+SM.AMPLIFICATIONSHADERPAYLOADSIZE        For shader '%0', payload size is greater than %1
 SM.APPENDANDCONSUMEONSAMEUAV             BufferUpdateCounter inc and dec on a given UAV (%d) cannot both be in the same shader for shader model less than 5.1.
 SM.CBUFFERARRAYOFFSETALIGNMENT           CBuffer array offset must be aligned to 16-bytes
 SM.CBUFFERELEMENTOVERFLOW                CBuffer elements must not overflow
 SM.CBUFFEROFFSETOVERLAP                  CBuffer offsets must not overlap
 SM.CBUFFERTEMPLATETYPEMUSTBESTRUCT       D3D12 constant/texture buffer template element can only be a struct
 SM.COMPLETEPOSITION                      Not all elements of SV_Position were written
+SM.CONSTANTINTERPMODE                    Interpolation mode must be constant for MS primitive output.
 SM.COUNTERONLYONSTRUCTBUF                BufferUpdateCounter valid only on structured buffers
 SM.CSNOSIGNATURES                        Compute shaders must not have shader signatures.
 SM.DOMAINLOCATIONIDXOOB                  DomainLocation component index out of bounds for the domain.
@@ -3075,8 +3136,17 @@ SM.INVALIDRESOURCECOMPTYPE               Invalid resource return type
 SM.INVALIDRESOURCEKIND                   Invalid resources kind
 SM.INVALIDTEXTUREKINDONUAV               Texture2DMS[Array] or TextureCube[Array] resources are not supported with UAVs
 SM.ISOLINEOUTPUTPRIMITIVEMISMATCH        Hull Shader declared with IsoLine Domain must specify output primitive point or line. Triangle_cw or triangle_ccw output are not compatible with the IsoLine Domain.
+SM.MAXMSSMSIZE                           Total Thread Group Shared Memory storage is %0, exceeded %1
 SM.MAXTGSMSIZE                           Total Thread Group Shared Memory storage is %0, exceeded %1
 SM.MAXTHEADGROUP                         Declared Thread Group Count %0 (X*Y*Z) is beyond the valid maximum of %1
+SM.MESHPSIGROWCOUNT                      For shader '%0', primitive output signatures are taking up more than %1 rows
+SM.MESHSHADERINOUTSIZE                   For shader '%0', input plus output size is greater than %1
+SM.MESHSHADERMAXPRIMITIVECOUNT           MS max primitive output count must be [0..%0].  %1 specified
+SM.MESHSHADERMAXVERTEXCOUNT              MS max vertex output count must be [0..%0].  %1 specified
+SM.MESHSHADEROUTPUTSIZE                  For shader '%0', vertex plus primitive output size is greater than %1
+SM.MESHSHADERPAYLOADSIZE                 For shader '%0', payload size is greater than %1
+SM.MESHTOTALSIGROWCOUNT                  For shader '%0', vertex and primitive output signatures are taking up more than %1 rows
+SM.MESHVSIGROWCOUNT                      For shader '%0', vertex output signatures are taking up more than %1 rows
 SM.MULTISTREAMMUSTBEPOINT                When multiple GS output streams are used they must be pointlists
 SM.NAME                                  Target shader model name must be known
 SM.NOINTERPMODE                          Interpolation mode must be undefined for VS input/PS output/patch constant.

+ 124 - 18
docs/SPIR-V.rst

@@ -264,6 +264,7 @@ Right now the following ``<builtin>`` are supported:
   Need ``SPV_KHR_shader_draw_parameters`` extension.
 * ``DeviceIndex``: The GLSL equivalent is ``gl_DeviceIndex``.
   Need ``SPV_KHR_device_group`` extension.
+* ``ViewportMaskNV``: The GLSL equivalent is ``gl_ViewportMask``.
 
 Please see Vulkan spec. `14.6. Built-In Variables <https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/vkspec.html#interfaces-builtin-variables>`_
 for detailed explanation of these builtins.
@@ -282,6 +283,7 @@ Supported extensions
 * SPV_EXT_shader_stencil_support
 * SPV_AMD_shader_explicit_vertex_parameter
 * SPV_GOOGLE_hlsl_functionality1
+* SPV_NV_mesh_shader
 
 Vulkan specific attributes
 --------------------------
@@ -1272,14 +1274,16 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 |                           | HSCPOut     | ``Position``             | N/A                   | ``Shader``                  |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | DSCPIn      | ``Position``             | N/A                   | ``Shader``                  |
-| SV_Position               +-------------+--------------------------+-----------------------+-----------------------------+
-|                           | DSOut       | ``Position``             | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_Position               | DSOut       | ``Position``             | N/A                   | ``Shader``                  |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSVIn       | ``Position``             | N/A                   | ``Shader``                  |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``Position``             | N/A                   | ``Shader``                  |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``FragCoord``            | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSOut       | ``Position``             | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 |                           | VSOut       | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
@@ -1288,14 +1292,16 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 |                           | HSCPOut     | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | DSCPIn      | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
-| SV_ClipDistance           +-------------+--------------------------+-----------------------+-----------------------------+
-|                           | DSOut       | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_ClipDistance           | DSOut       | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSVIn       | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSOut       | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 |                           | VSOut       | ``CullDistance``         | N/A                   | ``CullDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
@@ -1304,14 +1310,16 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 |                           | HSCPOut     | ``CullDistance``         | N/A                   | ``CullDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | DSCPIn      | ``CullDistance``         | N/A                   | ``CullDistance``            |
-| SV_CullDistance           +-------------+--------------------------+-----------------------+-----------------------------+
-|                           | DSOut       | ``CullDistance``         | N/A                   | ``CullDistance``            |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_CullDistance           | DSOut       | ``CullDistance``         | N/A                   | ``CullDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSVIn       | ``CullDistance``         | N/A                   | ``CullDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``CullDistance``         | N/A                   | ``CullDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``CullDistance``         | N/A                   | ``CullDistance``            |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSOut       | ``CullDistance``         | N/A                   | ``CullDistance``            |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 | SV_VertexID               | VSIn        | ``VertexIndex``          | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
@@ -1325,13 +1333,29 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 | SV_IsFrontFace            | PSIn        | ``FrontFacing``          | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
-| SV_DispatchThreadID       | CSIn        | ``GlobalInvocationId``   | N/A                   | ``Shader``                  |
+|                           | CSIn        | ``GlobalInvocationId``   | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_DispatchThreadID       | MSIn        | ``GlobalInvocationId``   | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | ASIn        | ``GlobalInvocationId``   | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
-| SV_GroupID                | CSIn        | ``WorkgroupId``          | N/A                   | ``Shader``                  |
+|                           | CSIn        | ``WorkgroupId``          | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_GroupID                | MSIn        | ``WorkgroupId``          | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | ASIn        | ``WorkgroupId``          | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
-| SV_GroupThreadID          | CSIn        | ``LocalInvocationId``    | N/A                   | ``Shader``                  |
+|                           | CSIn        | ``LocalInvocationId``    | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_GroupThreadID          | MSIn        | ``LocalInvocationId``    | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | ASIn        | ``LocalInvocationId``    | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
-| SV_GroupIndex             | CSIn        | ``LocalInvocationIndex`` | N/A                   | ``Shader``                  |
+|                           | CSIn        | ``LocalInvocationIndex`` | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_GroupIndex             | MSIn        | ``LocalInvocationIndex`` | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | ASIn        | ``LocalInvocationIndex`` | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 | SV_OutputControlPointID   | HSIn        | ``InvocationId``         | N/A                   | ``Tessellation``            |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
@@ -1344,12 +1368,14 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 |                           | PCIn        | ``PrimitiveId``          | N/A                   | ``Tessellation``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | DSIn        | ``PrimitiveId``          | N/A                   | ``Tessellation``            |
-| SV_PrimitiveID            +-------------+--------------------------+-----------------------+-----------------------------+
-|                           | GSIn        | ``PrimitiveId``          | N/A                   | ``Geometry``                |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_PrimitiveID            | GSIn        | ``PrimitiveId``          | N/A                   | ``Geometry``                |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``PrimitiveId``          | N/A                   | ``Geometry``                |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``PrimitiveId``          | N/A                   | ``Geometry``                |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSOut       | ``PrimitiveId``          | N/A                   | ``MeshShadingNV``           |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PCOut       | ``TessLevelOuter``       | N/A                   | ``Tessellation``            |
 | SV_TessFactor             +-------------+--------------------------+-----------------------+-----------------------------+
@@ -1366,12 +1392,16 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 | SV_Barycentrics           | PSIn        | ``BaryCoord*AMD``        | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``Layer``                | N/A                   | ``Geometry``                |
-| SV_RenderTargetArrayIndex +-------------+--------------------------+-----------------------+-----------------------------+
-|                           | PSIn        | ``Layer``                | N/A                   | ``Geometry``                |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_RenderTargetArrayIndex | PSIn        | ``Layer``                | N/A                   | ``Geometry``                |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSOut       | ``Layer``                | N/A                   | ``MeshShadingNV``           |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``ViewportIndex``        | N/A                   | ``MultiViewport``           |
-| SV_ViewportArrayIndex     +-------------+--------------------------+-----------------------+-----------------------------+
-|                           | PSIn        | ``ViewportIndex``        | N/A                   | ``MultiViewport``           |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_ViewportArrayIndex     | PSIn        | ``ViewportIndex``        | N/A                   | ``MultiViewport``           |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSOut       | ``ViewportIndex``        | N/A                   | ``MeshShadingNV``           |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``SampleMask``           | N/A                   | ``Shader``                  |
 | SV_Coverage               +-------------+--------------------------+-----------------------+-----------------------------+
@@ -1383,11 +1413,13 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | HSIn        | ``ViewIndex``            | N/A                   | ``MultiView``               |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
-| SV_ViewID                 | DSIn        | ``ViewIndex``            | N/A                   | ``MultiView``               |
-|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | DSIn        | ``ViewIndex``            | N/A                   | ``MultiView``               |
+| SV_ViewID                 +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSIn        | ``ViewIndex``            | N/A                   | ``MultiView``               |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``ViewIndex``            | N/A                   | ``MultiView``               |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSIn        | ``ViewIndex``            | N/A                   | ``MultiView``               |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 | SV_ShadingRate            | PSIn        | ``FragSizeEXT``          | N/A                   | ``FragmentDensityEXT``      |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
@@ -3069,6 +3101,80 @@ Callable Stage
     a.color = float4(0.0f,1.0f,0.0f,0.0f);
   }
 
+Mesh and Amplification Shaders
+------------------------------
+
+DirectX adds two new shader stages for the mesh shading pipeline: Mesh and Amplification.
+Amplification shaders correspond to Task shaders in Vulkan.
+
+| Refer to the following HLSL and SPIR-V specs for details:
+| https://docs.microsoft.com/<TBD>
+| https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/NV/SPV_NV_mesh_shader.asciidoc
+
+This section describes how Mesh and Amplification shaders are translated to SPIR-V for Vulkan.
+
+Entry Point Attributes
+~~~~~~~~~~~~~~~~~~~~~~
+The following HLSL attributes are attached to the main entry point of Mesh and/or Amplification
+shaders and are translated to SPIR-V execution modes according to the table below:
+
+.. table:: Mapping from HLSL attribute to SPIR-V execution mode
+
++--------------------+----------------+-------------------------+
+|  HLSL Attribute    |   Value        | SPIR-V Execution Mode   |
++====================+================+=========================+
+|                    | ``point``      | ``OutputPoints``        |
+|                    +----------------+-------------------------+
+| ``outputtopology`` | ``line``       | ``OutputLinesNV``       |
+|   (Mesh shader)    +----------------+-------------------------+
+|                    | ``triangle``   | ``OutputTrianglesNV``   |
++--------------------+----------------+-------------------------+
+| ``numthreads``     | ``X, Y, Z``    | ``LocalSize X, Y, Z``   |
+|                    | (X*Y*Z <= 128) |                         |
++--------------------+----------------+-------------------------+
+
+Intrinsics
+~~~~~~~~~~
+The following HLSL intrinsics are used in Mesh or Amplification shaders
+and are translated to SPIR-V intrinsics according to the table below:
+
+.. table:: Mapping from HLSL intrinsics to SPIR-V intrinsics
+
++-------------------------+--------------------+-----------------------------------------+
+|  HLSL Intrinsic         |  Parameters        | SPIR-V Intrinsic                        |
++=========================+====================+=========================================+
+| ``SetMeshOutputCounts`` | ``numVertices``    | ``PrimitiveCountNV numPrimitives``      |
+|     (Mesh shader)       | ``numPrimitives``  |                                         |
++-------------------------+--------------------+-----------------------------------------+
+|                         | ``ThreadX``        |                                         |
+| ``DispatchMesh``        | ``ThreadY``        |  ``OpControlBarrier``                   |
+| (Amplification shader)  | ``ThreadZ``        | ``TaskCountNV ThreadX*ThreadY*ThreadZ`` |
+|                         | ``MeshPayload``    |                                         |
++-------------------------+--------------------+-----------------------------------------+
+
+| \*For the ``DispatchMesh`` intrinsic, we also emit ``MeshPayload`` as an output block with the ``PerTaskNV`` decoration.
+
+Mesh Interface Variables
+~~~~~~~~~~~~~~~~~~~~~~~~
+Interface variables are defined for Mesh shaders using HLSL modifiers.
+The following table gives a high-level overview of the mapping:
+
+.. table:: Mapping from HLSL modifiers to SPIR-V definitions
+
++-----------------+-------------------------------------------------------------------------+
+|  HLSL modifier  | SPIR-V definition                                                       |
++=================+=========================================================================+
+| ``indices``     | Maps to SPIR-V intrinsic ``PrimitiveIndicesNV``                         |  
+|                 | Defines SPIR-V Execution Mode ``OutputPrimitivesNV <array-size>``       |
++-----------------+-------------------------------------------------------------------------+
+| ``vertices``    | Maps to per-vertex out attributes                                       |
+|                 | Defines existing SPIR-V Execution Mode ``OutputVertices <array-size>``  |
++-----------------+-------------------------------------------------------------------------+
+| ``primitives``  | Maps to per-primitive out attributes with ``PerPrimitiveNV`` decoration |
++-----------------+-------------------------------------------------------------------------+
+| ``payload``     | Maps to per-task in attributes with ``PerTaskNV`` decoration            |
++-----------------+-------------------------------------------------------------------------+
+
 
 Raytracing in Vulkan and SPIRV
 ==============================

+ 1 - 1
external/SPIRV-Headers

@@ -1 +1 @@
-Subproject commit de99d4d834aeb51dd9f099baa285bd44fd04bb3d
+Subproject commit 29c11140baaf9f7fdaa39a583672c556bf1795a1

+ 1 - 1
external/SPIRV-Tools

@@ -1 +1 @@
-Subproject commit df86bb44fe476515f9a298bacd8e1d4a3522e989
+Subproject commit 5081512502df32b38a38712adb0d8c1b23bb1c2e

+ 161 - 14
include/dxc/DXIL/DxilConstants.h

@@ -68,17 +68,33 @@ namespace DXIL {
   const float kHSMaxTessFactorUpperBound = 64.0f;
   const unsigned kHSDefaultInputControlPointCount = 1;
   const unsigned kMaxCSThreadsPerGroup = 1024;
-  const unsigned kMaxCSThreadGroupX	= 1024;
-  const unsigned kMaxCSThreadGroupY	= 1024;
+  const unsigned kMaxCSThreadGroupX = 1024;
+  const unsigned kMaxCSThreadGroupY = 1024;
   const unsigned kMaxCSThreadGroupZ = 64;
   const unsigned kMinCSThreadGroupX = 1;
   const unsigned kMinCSThreadGroupY = 1;
   const unsigned kMinCSThreadGroupZ = 1;
   const unsigned kMaxCS4XThreadsPerGroup = 768;
-  const unsigned kMaxCS4XThreadGroupX	= 768;
-  const unsigned kMaxCS4XThreadGroupY	= 768;
+  const unsigned kMaxCS4XThreadGroupX = 768;
+  const unsigned kMaxCS4XThreadGroupY = 768;
   const unsigned kMaxTGSMSize = 8192*4;
   const unsigned kMaxGSOutputTotalScalars = 1024;
+  const unsigned kMaxMSASThreadsPerGroup = 128;
+  const unsigned kMaxMSASThreadGroupX = 128;
+  const unsigned kMaxMSASThreadGroupY = 128;
+  const unsigned kMaxMSASThreadGroupZ = 128;
+  const unsigned kMinMSASThreadGroupX = 1;
+  const unsigned kMinMSASThreadGroupY = 1;
+  const unsigned kMinMSASThreadGroupZ = 1;
+  const unsigned kMaxMSASPayloadSize = 16384;
+  const unsigned kMaxMSOutputPrimitiveCount = 256;
+  const unsigned kMaxMSOutputVertexCount = 256;
+  const unsigned kMaxMSOutputTotalScalars = 32768;
+  const unsigned kMaxMSInputOutputTotalScalars = 41984;
+  const unsigned kMaxMSVSigRows = 32;
+  const unsigned kMaxMSPSigRows = 32;
+  const unsigned kMaxMSTotalSigRows = 32;
+  const unsigned kMaxMSSMSize = 1024 * 28;
 
   const float kMaxMipLodBias = 15.99f;
   const float kMinMipLodBias = -16.0f;
@@ -116,7 +132,7 @@ namespace DXIL {
     Invalid = 0,
     Input,
     Output,
-    PatchConstant,
+    PatchConstOrPrim,
   };
 
   // Must match D3D11_SHADER_VERSION_TYPE
@@ -134,6 +150,8 @@ namespace DXIL {
     ClosestHit,
     Miss,
     Callable,
+    Mesh,
+    Amplification,
     Invalid,
   };
 
@@ -171,6 +189,7 @@ namespace DXIL {
     ViewID,
     Barycentrics,
     ShadingRate,
+    CullPrimitive,
     Invalid,
   };
   // SemanticKind-ENUM:END
@@ -195,6 +214,10 @@ namespace DXIL {
     PSIn, // Pixel Shader input
     PSOut, // Pixel Shader output
     CSIn, // Compute Shader input
+    MSIn, // Mesh Shader input
+    MSOut, // Mesh Shader vertices output
+    MSPOut, // Mesh Shader primitives output
+    ASIn, // Amplification Shader input
     Invalid,
   };
   // SigPointKind-ENUM:END
@@ -291,6 +314,10 @@ namespace DXIL {
     Sampler,
     TBuffer,
     RTAccelerationStructure,
+    FeedbackTexture2DMinLOD,
+    FeedbackTexture2DTiled,
+    FeedbackTexture2DArrayMinLOD,
+    FeedbackTexture2DArrayTiled,
     NumEntries,
   };
 
@@ -299,6 +326,9 @@ namespace DXIL {
   // OPCODE-ENUM:BEGIN
   // Enumeration for operations specified by DXIL
   enum class OpCode : unsigned {
+    // Amplification shader instructions
+    DispatchMesh = 173, // Amplification shader intrinsic DispatchMesh
+  
     // AnyHit Terminals
     AcceptHitAndEndSearch = 156, // Used in an any hit shader to abort the ray query and the intersection shader (if any). The current hit is committed and execution passes to the closest hit shader with the closest hit recorded so far
     IgnoreHit = 155, // Used in an any hit shader to reject an intersection and terminate the shader
@@ -334,7 +364,7 @@ namespace DXIL {
     BitcastI32toF32 = 126, // bitcast between different sizes
     BitcastI64toF64 = 128, // bitcast between different sizes
   
-    // Compute shader
+    // Compute/Mesh/Amplification shader
     FlattenedThreadIdInGroup = 96, // provides a flattened index for a given thread within a given group (SV_GroupIndex)
     GroupId = 94, // reads the group ID (SV_GroupID)
     ThreadId = 93, // reads the thread ID
@@ -383,7 +413,44 @@ namespace DXIL {
     // Indirect Shader Invocation
     CallShader = 159, // Call a shader in the callable shader table supplied through the DispatchRays() API
     ReportHit = 158, // returns true if hit was accepted
-    TraceRay = 157, // returns the view index
+    TraceRay = 157, // initiates raytrace
+  
+    // Inline Ray Query
+    AllocateRayQuery = 178, // allocates space for RayQuery and returns a handle
+    RayQuery_Abort = 181, // aborts a ray query
+    RayQuery_CandidateGeometryIndex = 203, // returns candidate hit geometry index
+    RayQuery_CandidateInstanceID = 202, // returns candidate hit instance ID
+    RayQuery_CandidateInstanceIndex = 201, // returns candidate hit instance index
+    RayQuery_CandidateObjectRayDirection = 206, // returns candidate object ray direction
+    RayQuery_CandidateObjectRayOrigin = 205, // returns candidate hit object ray origin
+    RayQuery_CandidateObjectToWorld3x4 = 186, // returns matrix for transforming from object-space to world-space for a candidate hit.
+    RayQuery_CandidatePrimitiveIndex = 204, // returns candidate hit primitive index
+    RayQuery_CandidateProceduralPrimitiveNonOpaque = 190, // returns if current candidate procedural primitive is non opaque
+    RayQuery_CandidateTriangleBarycentrics = 193, // returns candidate triangle hit barycentrics
+    RayQuery_CandidateTriangleFrontFace = 191, // returns if current candidate triangle is front facing
+    RayQuery_CandidateTriangleRayT = 199, // returns float representing the parametric point on the ray for the current candidate triangle hit.
+    RayQuery_CandidateType = 185, // returns uint candidate type (CANDIDATE_TYPE) of the current hit candidate in a ray query, after Proceed() has returned true
+    RayQuery_CandidateWorldToObject3x4 = 187, // returns matrix for transforming from world-space to object-space for a candidate hit.
+    RayQuery_CommitNonOpaqueTriangleHit = 182, // commits a non opaque triangle hit
+    RayQuery_CommitProceduralPrimitiveHit = 183, // commits a procedural primitive hit
+    RayQuery_CommittedGeometryIndex = 209, // returns committed hit geometry index
+    RayQuery_CommittedInstanceID = 208, // returns committed hit instance ID
+    RayQuery_CommittedInstanceIndex = 207, // returns committed hit instance index
+    RayQuery_CommittedObjectRayDirection = 212, // returns committed object ray direction
+    RayQuery_CommittedObjectRayOrigin = 211, // returns committed hit object ray origin
+    RayQuery_CommittedObjectToWorld3x4 = 188, // returns matrix for transforming from object-space to world-space for a Committed hit.
+    RayQuery_CommittedPrimitiveIndex = 210, // returns committed hit primitive index
+    RayQuery_CommittedRayT = 200, // returns float representing the parametric point on the ray for the current committed hit.
+    RayQuery_CommittedStatus = 184, // returns uint status (COMMITTED_STATUS) of the committed hit in a ray query
+    RayQuery_CommittedTriangleBarycentrics = 194, // returns committed triangle hit barycentrics
+    RayQuery_CommittedTriangleFrontFace = 192, // returns if current committed triangle is front facing
+    RayQuery_CommittedWorldToObject3x4 = 189, // returns matrix for transforming from world-space to object-space for a Committed hit.
+    RayQuery_Proceed = 180, // advances a ray query
+    RayQuery_RayFlags = 195, // returns ray flags
+    RayQuery_RayTMin = 198, // returns float representing the parametric starting point for the ray.
+    RayQuery_TraceRayInline = 179, // initializes RayQuery for raytrace
+    RayQuery_WorldRayDirection = 197, // returns world ray direction
+    RayQuery_WorldRayOrigin = 196, // returns world ray origin
   
     // Legacy floating-point
     LegacyF16ToF32 = 131, // legacy function to convert half (f16) to float (f32) (this is not related to min-precision)
@@ -392,6 +459,13 @@ namespace DXIL {
     // Library create handle from resource struct (like HL intrinsic)
     CreateHandleForLib = 160, // create resource handle from resource struct for library
   
+    // Mesh shader instructions
+    EmitIndices = 169, // emit a primitive's vertex indices in a mesh shader
+    GetMeshPayload = 170, // get the mesh payload which is from amplification shader
+    SetMeshOutputCounts = 168, // Mesh shader intrinsic SetMeshOutputCounts
+    StorePrimitiveOutput = 172, // stores the value to mesh shader primitive output
+    StoreVertexOutput = 171, // stores the value to mesh shader vertex output
+  
     // Other
     CycleCounterLegacy = 109, // CycleCounterLegacy
   
@@ -436,6 +510,9 @@ namespace DXIL {
     // Raytracing hit uint System Values
     HitKind = 143, // Returns the value passed as HitKind in ReportIntersection().  If intersection was reported by fixed-function triangle intersection, HitKind will be one of HIT_KIND_TRIANGLE_FRONT_FACE or HIT_KIND_TRIANGLE_BACK_FACE.
   
+    // Raytracing object space uint System Values, raytracing tier 1.1
+    GeometryIndex = 213, // The autogenerated index of the current geometry in the bottom-level structure
+  
     // Raytracing object space uint System Values
     InstanceID = 141, // The user-provided InstanceID on the bottom-level acceleration structure instance within the top-level structure
     InstanceIndex = 142, // The autogenerated index of the current instance in the top-level structure
@@ -473,6 +550,12 @@ namespace DXIL {
     TextureLoad = 66, // reads texel data without any filtering or sampling
     TextureStore = 67, // writes texel data without any filtering or sampling
   
+    // Sampler Feedback
+    WriteSamplerFeedback = 174, // updates a feedback texture for a sampling operation
+    WriteSamplerFeedbackBias = 175, // updates a feedback texture for a sampling operation with a bias on the mipmap level
+    WriteSamplerFeedbackGrad = 177, // updates a feedback texture for a sampling operation with explicit gradients
+    WriteSamplerFeedbackLevel = 176, // updates a feedback texture for a sampling operation with a mipmap-level offset
+  
     // Synchronization
     AtomicBinOp = 78, // performs an atomic operation on two operands
     AtomicCompareExchange = 79, // atomic compare and exchange to memory
@@ -562,9 +645,9 @@ namespace DXIL {
     NumOpCodes_Dxil_1_2 = 141,
     NumOpCodes_Dxil_1_3 = 162,
     NumOpCodes_Dxil_1_4 = 165,
-    NumOpCodes_Dxil_1_5 = 168,
+    NumOpCodes_Dxil_1_5 = 214,
   
-    NumOpCodes = 168 // exclusive last value of enumeration
+    NumOpCodes = 214 // exclusive last value of enumeration
   };
   // OPCODE-ENUM:END
 
@@ -572,6 +655,9 @@ namespace DXIL {
   // OPCODECLASS-ENUM:BEGIN
   // Groups for DXIL operations with equivalent function templates
   enum class OpCodeClass : unsigned {
+    // Amplification shader instructions
+    DispatchMesh,
+  
     // AnyHit Terminals
     AcceptHitAndEndSearch,
     IgnoreHit,
@@ -593,7 +679,7 @@ namespace DXIL {
     BitcastI32toF32,
     BitcastI64toF64,
   
-    // Compute shader
+    // Compute/Mesh/Amplification shader
     FlattenedThreadIdInGroup,
     GroupId,
     ThreadId,
@@ -643,6 +729,17 @@ namespace DXIL {
     ReportHit,
     TraceRay,
   
+    // Inline Ray Query
+    AllocateRayQuery,
+    RayQuery_Abort,
+    RayQuery_CommitNonOpaqueTriangleHit,
+    RayQuery_CommitProceduralPrimitiveHit,
+    RayQuery_Proceed,
+    RayQuery_StateMatrix,
+    RayQuery_StateScalar,
+    RayQuery_StateVector,
+    RayQuery_TraceRayInline,
+  
     // LLVM Instructions
     LlvmInst,
   
@@ -653,6 +750,13 @@ namespace DXIL {
     // Library create handle from resource struct (like HL intrinsic)
     CreateHandleForLib,
   
+    // Mesh shader instructions
+    EmitIndices,
+    GetMeshPayload,
+    SetMeshOutputCounts,
+    StorePrimitiveOutput,
+    StoreVertexOutput,
+  
     // Other
     CycleCounterLegacy,
   
@@ -694,6 +798,9 @@ namespace DXIL {
     // Raytracing hit uint System Values
     HitKind,
   
+    // Raytracing object space uint System Values, raytracing tier 1.1
+    GeometryIndex,
+  
     // Raytracing object space uint System Values
     InstanceID,
     InstanceIndex,
@@ -731,6 +838,12 @@ namespace DXIL {
     TextureLoad,
     TextureStore,
   
+    // Sampler Feedback
+    WriteSamplerFeedback,
+    WriteSamplerFeedbackBias,
+    WriteSamplerFeedbackGrad,
+    WriteSamplerFeedbackLevel,
+  
     // Synchronization
     AtomicBinOp,
     AtomicCompareExchange,
@@ -778,9 +891,9 @@ namespace DXIL {
     NumOpClasses_Dxil_1_2 = 97,
     NumOpClasses_Dxil_1_3 = 118,
     NumOpClasses_Dxil_1_4 = 120,
-    NumOpClasses_Dxil_1_5 = 123,
+    NumOpClasses_Dxil_1_5 = 143,
   
-    NumOpClasses = 123 // exclusive last value of enumeration
+    NumOpClasses = 143 // exclusive last value of enumeration
   };
   // OPCODECLASS-ENUM:END
 
@@ -807,11 +920,12 @@ namespace DXIL {
     const unsigned kLoadInputColOpIdx = 3;
     const unsigned kLoadInputVertexIDOpIdx = 4;
 
-    // StoreOutput.
+    // StoreOutput, StoreVertexOutput, StorePrimitiveOutput
     const unsigned kStoreOutputIDOpIdx = 1;
     const unsigned kStoreOutputRowOpIdx = 2;
     const unsigned kStoreOutputColOpIdx = 3;
     const unsigned kStoreOutputValOpIdx = 4;
+    const unsigned kStoreOutputVPIDOpIdx = 5;
 
     // DomainLocation.
     const unsigned kDomainLocationColOpIdx = 1;
@@ -910,9 +1024,20 @@ namespace DXIL {
     const unsigned kTraceRayPayloadOpIdx = 15;
     const unsigned kTraceRayNumOp = 16;
 
+    // TraceRayInline
+    const unsigned kTraceRayInlineRayDescOpIdx = 5;
+    const unsigned kTraceRayInlineNumOp = 13;
 
     // Emit/Cut
     const unsigned kStreamEmitCutIDOpIdx = 1;
+
+    // StoreVertexOutput/StorePrimitiveOutput.
+    const unsigned kMSStoreOutputIDOpIdx = 1;
+    const unsigned kMSStoreOutputRowOpIdx = 2;
+    const unsigned kMSStoreOutputColOpIdx = 3;
+    const unsigned kMSStoreOutputVIdxOpIdx = 4;
+    const unsigned kMSStoreOutputValOpIdx = 5;
+
     // TODO: add operand index for all the OpCodeClass.
   }
 
@@ -1027,6 +1152,15 @@ namespace DXIL {
     LastEntry,
   };
 
+  enum class MeshOutputTopology
+  {
+    Undefined = 0,
+    Line = 1,
+    Triangle = 2,
+
+    LastEntry,
+  };
+
   // Tessellator partitioning, must match D3D_TESSELLATOR_PARTITIONING
   enum class TessellatorPartitioning : unsigned {
     Undefined = 0,
@@ -1166,8 +1300,10 @@ namespace DXIL {
   const uint64_t ShaderFeatureInfo_Barycentrics = 0x20000;
   const uint64_t ShaderFeatureInfo_NativeLowPrecision = 0x40000;
   const uint64_t ShaderFeatureInfo_ShadingRate = 0x80000;
+  const uint64_t ShaderFeatureInfo_Raytracing_Tier_1_1 = 0x100000;
+  const uint64_t ShaderFeatureInfo_SamplerFeedback = 0x200000;
 
-  const unsigned ShaderFeatureInfoCount = 20;
+  const unsigned ShaderFeatureInfoCount = 22;
 
   // DxilSubobjectType must match D3D12_STATE_SUBOBJECT_TYPE, with
   // certain values reserved, since they cannot be used from Dxil.
@@ -1201,6 +1337,17 @@ namespace DXIL {
     LastEntry,
   };
 
+  enum class CommittedStatus : uint32_t {
+    CommittedNothing = 0,
+    CommittedTriangleHit = 1,
+    CommittedProceduralPrimitiveHit = 2,
+  };
+
+  enum class CandidateType : uint32_t {
+    CandidateNonOpaqueTriangle = 0,
+    CandidateProceduralPrimitive = 1,
+  };
+
   inline bool IsValidHitGroupType(HitGroupType type) {
     return (type >= HitGroupType::Triangle && type < HitGroupType::LastEntry);
   }

+ 17 - 1
include/dxc/DXIL/DxilFunctionProps.h

@@ -67,6 +67,19 @@ struct DxilFunctionProps {
       };
       unsigned attributeSizeInBytes;
     } Ray;
+    // Mesh shader.
+    struct {
+      unsigned numThreads[3];
+      unsigned maxVertexCount;
+      unsigned maxPrimitiveCount;
+      DXIL::MeshOutputTopology outputTopology;
+      // The following doesn't go into metadata
+      unsigned payloadByteSize;
+    } MS;
+    // Amplification shader.
+    struct {
+      unsigned numThreads[3];
+    } AS;
   } ShaderProps;
   DXIL::ShaderKind shaderKind;
   // TODO: Should we have an unmangled name here for ray tracing shaders?
@@ -77,7 +90,8 @@ struct DxilFunctionProps {
   bool IsDS() const     { return shaderKind == DXIL::ShaderKind::Domain; }
   bool IsCS() const     { return shaderKind == DXIL::ShaderKind::Compute; }
   bool IsGraphics() const {
-    return (shaderKind >= DXIL::ShaderKind::Pixel && shaderKind <= DXIL::ShaderKind::Domain);
+    return (shaderKind >= DXIL::ShaderKind::Pixel && shaderKind <= DXIL::ShaderKind::Domain) ||
+           shaderKind == DXIL::ShaderKind::Mesh || shaderKind == DXIL::ShaderKind::Amplification;
   }
   bool IsRayGeneration() const { return shaderKind == DXIL::ShaderKind::RayGeneration; }
   bool IsIntersection() const { return shaderKind == DXIL::ShaderKind::Intersection; }
@@ -88,6 +102,8 @@ struct DxilFunctionProps {
   bool IsRay() const {
     return (shaderKind >= DXIL::ShaderKind::RayGeneration && shaderKind <= DXIL::ShaderKind::Callable);
   }
+  bool IsMS() const { return shaderKind == DXIL::ShaderKind::Mesh; }
+  bool IsAS() const { return shaderKind == DXIL::ShaderKind::Amplification; }
 };
 
 } // namespace hlsl

+ 1365 - 1
include/dxc/DXIL/DxilInstructions.h

@@ -5173,7 +5173,7 @@ struct DxilInst_AcceptHitAndEndSearch {
   bool requiresUniformInputs() const { return false; }
 };
 
-/// This instruction returns the view index
+/// This instruction initiates raytrace
 struct DxilInst_TraceRay {
   llvm::Instruction *Instr;
   // Construction and identification
@@ -5549,5 +5549,1369 @@ struct DxilInst_WaveMultiPrefixBitCount {
   llvm::Value *get_mask3() const { return Instr->getOperand(5); }
   void set_mask3(llvm::Value *val) { Instr->setOperand(5, val); }
 };
+
+/// This instruction is the Mesh shader intrinsic SetMeshOutputCounts
+struct DxilInst_SetMeshOutputCounts {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_SetMeshOutputCounts(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::SetMeshOutputCounts);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (3 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_numVertices = 1,
+    arg_numPrimitives = 2,
+  };
+  // Accessors
+  llvm::Value *get_numVertices() const { return Instr->getOperand(1); }
+  void set_numVertices(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_numPrimitives() const { return Instr->getOperand(2); }
+  void set_numPrimitives(llvm::Value *val) { Instr->setOperand(2, val); }
+};
+
+/// This instruction emits a primitive's vertex indices in a mesh shader
+struct DxilInst_EmitIndices {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_EmitIndices(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::EmitIndices);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (5 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_PrimitiveIndex = 1,
+    arg_VertexIndex0 = 2,
+    arg_VertexIndex1 = 3,
+    arg_VertexIndex2 = 4,
+  };
+  // Accessors
+  llvm::Value *get_PrimitiveIndex() const { return Instr->getOperand(1); }
+  void set_PrimitiveIndex(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_VertexIndex0() const { return Instr->getOperand(2); }
+  void set_VertexIndex0(llvm::Value *val) { Instr->setOperand(2, val); }
+  llvm::Value *get_VertexIndex1() const { return Instr->getOperand(3); }
+  void set_VertexIndex1(llvm::Value *val) { Instr->setOperand(3, val); }
+  llvm::Value *get_VertexIndex2() const { return Instr->getOperand(4); }
+  void set_VertexIndex2(llvm::Value *val) { Instr->setOperand(4, val); }
+};
+
+/// This instruction gets the mesh payload, which comes from the amplification shader
+struct DxilInst_GetMeshPayload {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_GetMeshPayload(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::GetMeshPayload);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (1 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+};
+
+/// This instruction stores the value to mesh shader vertex output
+struct DxilInst_StoreVertexOutput {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_StoreVertexOutput(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::StoreVertexOutput);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (6 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_outputSigId = 1,
+    arg_rowIndex = 2,
+    arg_colIndex = 3,
+    arg_value = 4,
+    arg_vertexIndex = 5,
+  };
+  // Accessors
+  llvm::Value *get_outputSigId() const { return Instr->getOperand(1); }
+  void set_outputSigId(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_rowIndex() const { return Instr->getOperand(2); }
+  void set_rowIndex(llvm::Value *val) { Instr->setOperand(2, val); }
+  llvm::Value *get_colIndex() const { return Instr->getOperand(3); }
+  void set_colIndex(llvm::Value *val) { Instr->setOperand(3, val); }
+  llvm::Value *get_value() const { return Instr->getOperand(4); }
+  void set_value(llvm::Value *val) { Instr->setOperand(4, val); }
+  llvm::Value *get_vertexIndex() const { return Instr->getOperand(5); }
+  void set_vertexIndex(llvm::Value *val) { Instr->setOperand(5, val); }
+};
+
+/// This instruction stores the value to mesh shader primitive output
+struct DxilInst_StorePrimitiveOutput {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_StorePrimitiveOutput(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::StorePrimitiveOutput);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (6 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_outputSigId = 1,
+    arg_rowIndex = 2,
+    arg_colIndex = 3,
+    arg_value = 4,
+    arg_primitiveIndex = 5,
+  };
+  // Accessors
+  llvm::Value *get_outputSigId() const { return Instr->getOperand(1); }
+  void set_outputSigId(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_rowIndex() const { return Instr->getOperand(2); }
+  void set_rowIndex(llvm::Value *val) { Instr->setOperand(2, val); }
+  llvm::Value *get_colIndex() const { return Instr->getOperand(3); }
+  void set_colIndex(llvm::Value *val) { Instr->setOperand(3, val); }
+  llvm::Value *get_value() const { return Instr->getOperand(4); }
+  void set_value(llvm::Value *val) { Instr->setOperand(4, val); }
+  llvm::Value *get_primitiveIndex() const { return Instr->getOperand(5); }
+  void set_primitiveIndex(llvm::Value *val) { Instr->setOperand(5, val); }
+};
+
+/// This instruction is the Amplification shader intrinsic DispatchMesh
+struct DxilInst_DispatchMesh {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_DispatchMesh(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::DispatchMesh);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (5 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_threadGroupCountX = 1,
+    arg_threadGroupCountY = 2,
+    arg_threadGroupCountZ = 3,
+    arg_payload = 4,
+  };
+  // Accessors
+  llvm::Value *get_threadGroupCountX() const { return Instr->getOperand(1); }
+  void set_threadGroupCountX(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_threadGroupCountY() const { return Instr->getOperand(2); }
+  void set_threadGroupCountY(llvm::Value *val) { Instr->setOperand(2, val); }
+  llvm::Value *get_threadGroupCountZ() const { return Instr->getOperand(3); }
+  void set_threadGroupCountZ(llvm::Value *val) { Instr->setOperand(3, val); }
+  llvm::Value *get_payload() const { return Instr->getOperand(4); }
+  void set_payload(llvm::Value *val) { Instr->setOperand(4, val); }
+};
+
+/// This instruction updates a feedback texture for a sampling operation
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::WriteSamplerFeedback; accessors address call operands by the
+/// OperandIdx positions.
+struct DxilInst_WriteSamplerFeedback {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_WriteSamplerFeedback(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::WriteSamplerFeedback);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 8 arguments: argument 0, which
+  // this wrapper does not name, plus the 7 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 8 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_feedbackTex = 1,
+    arg_sampledTex = 2,
+    arg_sampler = 3,
+    arg_c0 = 4,
+    arg_c1 = 5,
+    arg_c2 = 6,
+    arg_clamp = 7,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_feedbackTex() const { return Instr->getOperand(arg_feedbackTex); }
+  void set_feedbackTex(llvm::Value *val) { Instr->setOperand(arg_feedbackTex, val); }
+  llvm::Value *get_sampledTex() const { return Instr->getOperand(arg_sampledTex); }
+  void set_sampledTex(llvm::Value *val) { Instr->setOperand(arg_sampledTex, val); }
+  llvm::Value *get_sampler() const { return Instr->getOperand(arg_sampler); }
+  void set_sampler(llvm::Value *val) { Instr->setOperand(arg_sampler, val); }
+  llvm::Value *get_c0() const { return Instr->getOperand(arg_c0); }
+  void set_c0(llvm::Value *val) { Instr->setOperand(arg_c0, val); }
+  llvm::Value *get_c1() const { return Instr->getOperand(arg_c1); }
+  void set_c1(llvm::Value *val) { Instr->setOperand(arg_c1, val); }
+  llvm::Value *get_c2() const { return Instr->getOperand(arg_c2); }
+  void set_c2(llvm::Value *val) { Instr->setOperand(arg_c2, val); }
+  llvm::Value *get_clamp() const { return Instr->getOperand(arg_clamp); }
+  void set_clamp(llvm::Value *val) { Instr->setOperand(arg_clamp, val); }
+};
+
+/// This instruction updates a feedback texture for a sampling operation with a bias on the mipmap level
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::WriteSamplerFeedbackBias; accessors address call operands by the
+/// OperandIdx positions.
+struct DxilInst_WriteSamplerFeedbackBias {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_WriteSamplerFeedbackBias(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::WriteSamplerFeedbackBias);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 9 arguments: argument 0, which
+  // this wrapper does not name, plus the 8 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 9 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_feedbackTex = 1,
+    arg_sampledTex = 2,
+    arg_sampler = 3,
+    arg_c0 = 4,
+    arg_c1 = 5,
+    arg_c2 = 6,
+    arg_bias = 7,
+    arg_clamp = 8,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_feedbackTex() const { return Instr->getOperand(arg_feedbackTex); }
+  void set_feedbackTex(llvm::Value *val) { Instr->setOperand(arg_feedbackTex, val); }
+  llvm::Value *get_sampledTex() const { return Instr->getOperand(arg_sampledTex); }
+  void set_sampledTex(llvm::Value *val) { Instr->setOperand(arg_sampledTex, val); }
+  llvm::Value *get_sampler() const { return Instr->getOperand(arg_sampler); }
+  void set_sampler(llvm::Value *val) { Instr->setOperand(arg_sampler, val); }
+  llvm::Value *get_c0() const { return Instr->getOperand(arg_c0); }
+  void set_c0(llvm::Value *val) { Instr->setOperand(arg_c0, val); }
+  llvm::Value *get_c1() const { return Instr->getOperand(arg_c1); }
+  void set_c1(llvm::Value *val) { Instr->setOperand(arg_c1, val); }
+  llvm::Value *get_c2() const { return Instr->getOperand(arg_c2); }
+  void set_c2(llvm::Value *val) { Instr->setOperand(arg_c2, val); }
+  llvm::Value *get_bias() const { return Instr->getOperand(arg_bias); }
+  void set_bias(llvm::Value *val) { Instr->setOperand(arg_bias, val); }
+  llvm::Value *get_clamp() const { return Instr->getOperand(arg_clamp); }
+  void set_clamp(llvm::Value *val) { Instr->setOperand(arg_clamp, val); }
+};
+
+/// This instruction updates a feedback texture for a sampling operation with a mipmap-level offset
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::WriteSamplerFeedbackLevel; accessors address call operands by the
+/// OperandIdx positions.
+struct DxilInst_WriteSamplerFeedbackLevel {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_WriteSamplerFeedbackLevel(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::WriteSamplerFeedbackLevel);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 8 arguments: argument 0, which
+  // this wrapper does not name, plus the 7 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 8 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_feedbackTex = 1,
+    arg_sampledTex = 2,
+    arg_sampler = 3,
+    arg_c0 = 4,
+    arg_c1 = 5,
+    arg_c2 = 6,
+    arg_lod = 7,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_feedbackTex() const { return Instr->getOperand(arg_feedbackTex); }
+  void set_feedbackTex(llvm::Value *val) { Instr->setOperand(arg_feedbackTex, val); }
+  llvm::Value *get_sampledTex() const { return Instr->getOperand(arg_sampledTex); }
+  void set_sampledTex(llvm::Value *val) { Instr->setOperand(arg_sampledTex, val); }
+  llvm::Value *get_sampler() const { return Instr->getOperand(arg_sampler); }
+  void set_sampler(llvm::Value *val) { Instr->setOperand(arg_sampler, val); }
+  llvm::Value *get_c0() const { return Instr->getOperand(arg_c0); }
+  void set_c0(llvm::Value *val) { Instr->setOperand(arg_c0, val); }
+  llvm::Value *get_c1() const { return Instr->getOperand(arg_c1); }
+  void set_c1(llvm::Value *val) { Instr->setOperand(arg_c1, val); }
+  llvm::Value *get_c2() const { return Instr->getOperand(arg_c2); }
+  void set_c2(llvm::Value *val) { Instr->setOperand(arg_c2, val); }
+  llvm::Value *get_lod() const { return Instr->getOperand(arg_lod); }
+  void set_lod(llvm::Value *val) { Instr->setOperand(arg_lod, val); }
+};
+
+/// This instruction updates a feedback texture for a sampling operation with explicit gradients
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::WriteSamplerFeedbackGrad; accessors address call operands by the
+/// OperandIdx positions.
+struct DxilInst_WriteSamplerFeedbackGrad {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_WriteSamplerFeedbackGrad(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::WriteSamplerFeedbackGrad);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 10 arguments: argument 0, which
+  // this wrapper does not name, plus the 9 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 10 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_feedbackTex = 1,
+    arg_sampledTex = 2,
+    arg_sampler = 3,
+    arg_c0 = 4,
+    arg_c1 = 5,
+    arg_c2 = 6,
+    arg_ddx = 7,
+    arg_ddy = 8,
+    arg_clamp = 9,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_feedbackTex() const { return Instr->getOperand(arg_feedbackTex); }
+  void set_feedbackTex(llvm::Value *val) { Instr->setOperand(arg_feedbackTex, val); }
+  llvm::Value *get_sampledTex() const { return Instr->getOperand(arg_sampledTex); }
+  void set_sampledTex(llvm::Value *val) { Instr->setOperand(arg_sampledTex, val); }
+  llvm::Value *get_sampler() const { return Instr->getOperand(arg_sampler); }
+  void set_sampler(llvm::Value *val) { Instr->setOperand(arg_sampler, val); }
+  llvm::Value *get_c0() const { return Instr->getOperand(arg_c0); }
+  void set_c0(llvm::Value *val) { Instr->setOperand(arg_c0, val); }
+  llvm::Value *get_c1() const { return Instr->getOperand(arg_c1); }
+  void set_c1(llvm::Value *val) { Instr->setOperand(arg_c1, val); }
+  llvm::Value *get_c2() const { return Instr->getOperand(arg_c2); }
+  void set_c2(llvm::Value *val) { Instr->setOperand(arg_c2, val); }
+  llvm::Value *get_ddx() const { return Instr->getOperand(arg_ddx); }
+  void set_ddx(llvm::Value *val) { Instr->setOperand(arg_ddx, val); }
+  llvm::Value *get_ddy() const { return Instr->getOperand(arg_ddy); }
+  void set_ddy(llvm::Value *val) { Instr->setOperand(arg_ddy, val); }
+  llvm::Value *get_clamp() const { return Instr->getOperand(arg_clamp); }
+  void set_clamp(llvm::Value *val) { Instr->setOperand(arg_clamp, val); }
+};
+
+/// This instruction allocates space for RayQuery and returns a handle
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::AllocateRayQuery; accessors address call operands by the
+/// OperandIdx positions.
+struct DxilInst_AllocateRayQuery {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_AllocateRayQuery(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::AllocateRayQuery);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 2 arguments: argument 0, which
+  // this wrapper does not name, plus constRayFlags. Anything that is not a
+  // CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 2 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_constRayFlags = 1,
+  };
+  // Accessors
+  llvm::Value *get_constRayFlags() const { return Instr->getOperand(arg_constRayFlags); }
+  void set_constRayFlags(llvm::Value *val) { Instr->setOperand(arg_constRayFlags, val); }
+  // Reads constRayFlags as a zero-extended integer truncated to uint32_t.
+  // NOTE(review): the dyn_cast result is dereferenced unchecked, so the
+  // operand must already be an llvm::ConstantInt when this is called.
+  uint32_t get_constRayFlags_val() const { return (uint32_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(arg_constRayFlags))->getZExtValue()); }
+  // Replaces constRayFlags with a 32-bit integer constant holding val.
+  void set_constRayFlags_val(uint32_t val) { Instr->setOperand(arg_constRayFlags, llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), 32), llvm::APInt(32, (uint64_t)val))); }
+};
+
+/// This instruction initializes RayQuery for raytrace
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_TraceRayInline; accessors address call operands by the
+/// OperandIdx positions.
+struct DxilInst_RayQuery_TraceRayInline {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_TraceRayInline(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_TraceRayInline);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 13 arguments: argument 0, which
+  // this wrapper does not name, plus the 12 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 13 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_accelerationStructure = 2,
+    arg_rayFlags = 3,
+    arg_instanceInclusionMask = 4,
+    arg_origin_X = 5,
+    arg_origin_Y = 6,
+    arg_origin_Z = 7,
+    arg_tMin = 8,
+    arg_direction_X = 9,
+    arg_direction_Y = 10,
+    arg_direction_Z = 11,
+    arg_tMax = 12,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+  llvm::Value *get_accelerationStructure() const { return Instr->getOperand(arg_accelerationStructure); }
+  void set_accelerationStructure(llvm::Value *val) { Instr->setOperand(arg_accelerationStructure, val); }
+  llvm::Value *get_rayFlags() const { return Instr->getOperand(arg_rayFlags); }
+  void set_rayFlags(llvm::Value *val) { Instr->setOperand(arg_rayFlags, val); }
+  llvm::Value *get_instanceInclusionMask() const { return Instr->getOperand(arg_instanceInclusionMask); }
+  void set_instanceInclusionMask(llvm::Value *val) { Instr->setOperand(arg_instanceInclusionMask, val); }
+  llvm::Value *get_origin_X() const { return Instr->getOperand(arg_origin_X); }
+  void set_origin_X(llvm::Value *val) { Instr->setOperand(arg_origin_X, val); }
+  llvm::Value *get_origin_Y() const { return Instr->getOperand(arg_origin_Y); }
+  void set_origin_Y(llvm::Value *val) { Instr->setOperand(arg_origin_Y, val); }
+  llvm::Value *get_origin_Z() const { return Instr->getOperand(arg_origin_Z); }
+  void set_origin_Z(llvm::Value *val) { Instr->setOperand(arg_origin_Z, val); }
+  llvm::Value *get_tMin() const { return Instr->getOperand(arg_tMin); }
+  void set_tMin(llvm::Value *val) { Instr->setOperand(arg_tMin, val); }
+  llvm::Value *get_direction_X() const { return Instr->getOperand(arg_direction_X); }
+  void set_direction_X(llvm::Value *val) { Instr->setOperand(arg_direction_X, val); }
+  llvm::Value *get_direction_Y() const { return Instr->getOperand(arg_direction_Y); }
+  void set_direction_Y(llvm::Value *val) { Instr->setOperand(arg_direction_Y, val); }
+  llvm::Value *get_direction_Z() const { return Instr->getOperand(arg_direction_Z); }
+  void set_direction_Z(llvm::Value *val) { Instr->setOperand(arg_direction_Z, val); }
+  llvm::Value *get_tMax() const { return Instr->getOperand(arg_tMax); }
+  void set_tMax(llvm::Value *val) { Instr->setOperand(arg_tMax, val); }
+};
+
+/// This instruction advances a ray query
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_Proceed; accessors address call operands by the
+/// OperandIdx positions.
+struct DxilInst_RayQuery_Proceed {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_Proceed(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_Proceed);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 2 arguments: argument 0, which
+  // this wrapper does not name, plus rayQueryHandle. Anything that is not a
+  // CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 2 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+};
+
+/// This instruction aborts a ray query
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_Abort; accessors address call operands by the
+/// OperandIdx positions.
+struct DxilInst_RayQuery_Abort {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_Abort(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_Abort);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 2 arguments: argument 0, which
+  // this wrapper does not name, plus rayQueryHandle. Anything that is not a
+  // CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 2 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+};
+
+/// This instruction commits a non opaque triangle hit
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CommitNonOpaqueTriangleHit; accessors address call
+/// operands by the OperandIdx positions.
+struct DxilInst_RayQuery_CommitNonOpaqueTriangleHit {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommitNonOpaqueTriangleHit(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommitNonOpaqueTriangleHit);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 2 arguments: argument 0, which
+  // this wrapper does not name, plus rayQueryHandle. Anything that is not a
+  // CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 2 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+};
+
+/// This instruction commits a procedural primitive hit
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CommitProceduralPrimitiveHit; accessors address call
+/// operands by the OperandIdx positions.
+struct DxilInst_RayQuery_CommitProceduralPrimitiveHit {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommitProceduralPrimitiveHit(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommitProceduralPrimitiveHit);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 3 arguments: argument 0, which
+  // this wrapper does not name, plus the 2 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 3 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_t = 2,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+  llvm::Value *get_t() const { return Instr->getOperand(arg_t); }
+  void set_t(llvm::Value *val) { Instr->setOperand(arg_t, val); }
+};
+
+/// This instruction returns uint status (COMMITTED_STATUS) of the committed hit in a ray query
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CommittedStatus; accessors address call operands by the
+/// OperandIdx positions.
+struct DxilInst_RayQuery_CommittedStatus {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommittedStatus(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommittedStatus);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 2 arguments: argument 0, which
+  // this wrapper does not name, plus rayQueryHandle. Anything that is not a
+  // CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 2 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+};
+
+/// This instruction returns uint candidate type (CANDIDATE_TYPE) of the current hit candidate in a ray query, after Proceed() has returned true
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CandidateType; accessors address call operands by the
+/// OperandIdx positions.
+struct DxilInst_RayQuery_CandidateType {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CandidateType(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CandidateType);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 2 arguments: argument 0, which
+  // this wrapper does not name, plus rayQueryHandle. Anything that is not a
+  // CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 2 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+};
+
+/// This instruction returns matrix for transforming from object-space to world-space for a candidate hit.
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CandidateObjectToWorld3x4; accessors address call
+/// operands by the OperandIdx positions.
+struct DxilInst_RayQuery_CandidateObjectToWorld3x4 {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CandidateObjectToWorld3x4(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CandidateObjectToWorld3x4);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 4 arguments: argument 0, which
+  // this wrapper does not name, plus the 3 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 4 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_row = 2,
+    arg_col = 3,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+  llvm::Value *get_row() const { return Instr->getOperand(arg_row); }
+  void set_row(llvm::Value *val) { Instr->setOperand(arg_row, val); }
+  llvm::Value *get_col() const { return Instr->getOperand(arg_col); }
+  void set_col(llvm::Value *val) { Instr->setOperand(arg_col, val); }
+};
+
+/// This instruction returns matrix for transforming from world-space to object-space for a candidate hit.
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CandidateWorldToObject3x4; accessors address call
+/// operands by the OperandIdx positions.
+struct DxilInst_RayQuery_CandidateWorldToObject3x4 {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CandidateWorldToObject3x4(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CandidateWorldToObject3x4);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 4 arguments: argument 0, which
+  // this wrapper does not name, plus the 3 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 4 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_row = 2,
+    arg_col = 3,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+  llvm::Value *get_row() const { return Instr->getOperand(arg_row); }
+  void set_row(llvm::Value *val) { Instr->setOperand(arg_row, val); }
+  llvm::Value *get_col() const { return Instr->getOperand(arg_col); }
+  void set_col(llvm::Value *val) { Instr->setOperand(arg_col, val); }
+};
+
+/// This instruction returns matrix for transforming from object-space to world-space for a Committed hit.
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CommittedObjectToWorld3x4; accessors address call
+/// operands by the OperandIdx positions.
+struct DxilInst_RayQuery_CommittedObjectToWorld3x4 {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommittedObjectToWorld3x4(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommittedObjectToWorld3x4);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 4 arguments: argument 0, which
+  // this wrapper does not name, plus the 3 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 4 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_row = 2,
+    arg_col = 3,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+  llvm::Value *get_row() const { return Instr->getOperand(arg_row); }
+  void set_row(llvm::Value *val) { Instr->setOperand(arg_row, val); }
+  llvm::Value *get_col() const { return Instr->getOperand(arg_col); }
+  void set_col(llvm::Value *val) { Instr->setOperand(arg_col, val); }
+};
+
+/// This instruction returns matrix for transforming from world-space to object-space for a Committed hit.
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CommittedWorldToObject3x4; accessors address call
+/// operands by the OperandIdx positions.
+struct DxilInst_RayQuery_CommittedWorldToObject3x4 {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommittedWorldToObject3x4(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommittedWorldToObject3x4);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 4 arguments: argument 0, which
+  // this wrapper does not name, plus the 3 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 4 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_row = 2,
+    arg_col = 3,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+  llvm::Value *get_row() const { return Instr->getOperand(arg_row); }
+  void set_row(llvm::Value *val) { Instr->setOperand(arg_row, val); }
+  llvm::Value *get_col() const { return Instr->getOperand(arg_col); }
+  void set_col(llvm::Value *val) { Instr->setOperand(arg_col, val); }
+};
+
+/// This instruction returns whether the current candidate procedural primitive is non opaque
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CandidateProceduralPrimitiveNonOpaque; accessors address
+/// call operands by the OperandIdx positions.
+struct DxilInst_RayQuery_CandidateProceduralPrimitiveNonOpaque {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CandidateProceduralPrimitiveNonOpaque(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CandidateProceduralPrimitiveNonOpaque);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 2 arguments: argument 0, which
+  // this wrapper does not name, plus rayQueryHandle. Anything that is not a
+  // CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 2 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+};
+
+/// This instruction returns whether the current candidate triangle is front facing
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CandidateTriangleFrontFace; accessors address call
+/// operands by the OperandIdx positions.
+struct DxilInst_RayQuery_CandidateTriangleFrontFace {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CandidateTriangleFrontFace(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CandidateTriangleFrontFace);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 2 arguments: argument 0, which
+  // this wrapper does not name, plus rayQueryHandle. Anything that is not a
+  // CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 2 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+};
+
+/// This instruction returns whether the current committed triangle is front facing
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CommittedTriangleFrontFace; accessors address call
+/// operands by the OperandIdx positions.
+struct DxilInst_RayQuery_CommittedTriangleFrontFace {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommittedTriangleFrontFace(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommittedTriangleFrontFace);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 2 arguments: argument 0, which
+  // this wrapper does not name, plus rayQueryHandle. Anything that is not a
+  // CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 2 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+};
+
+/// This instruction returns candidate triangle hit barycentrics
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CandidateTriangleBarycentrics; accessors address call
+/// operands by the OperandIdx positions.
+struct DxilInst_RayQuery_CandidateTriangleBarycentrics {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CandidateTriangleBarycentrics(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CandidateTriangleBarycentrics);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 3 arguments: argument 0, which
+  // this wrapper does not name, plus the 2 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 3 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_component = 2,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+  llvm::Value *get_component() const { return Instr->getOperand(arg_component); }
+  void set_component(llvm::Value *val) { Instr->setOperand(arg_component, val); }
+  // Reads component as a zero-extended integer narrowed to int8_t.
+  // NOTE(review): the dyn_cast result is dereferenced unchecked, so the
+  // operand must already be an llvm::ConstantInt when this is called.
+  int8_t get_component_val() const { return (int8_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(arg_component))->getZExtValue()); }
+  // Replaces component with an 8-bit integer constant holding val.
+  void set_component_val(int8_t val) { Instr->setOperand(arg_component, llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), 8), llvm::APInt(8, (uint64_t)val))); }
+};
+
+/// This instruction returns committed triangle hit barycentrics
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CommittedTriangleBarycentrics; accessors address call
+/// operands by the OperandIdx positions.
+struct DxilInst_RayQuery_CommittedTriangleBarycentrics {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommittedTriangleBarycentrics(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommittedTriangleBarycentrics);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 3 arguments: argument 0, which
+  // this wrapper does not name, plus the 2 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 3 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_component = 2,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+  llvm::Value *get_component() const { return Instr->getOperand(arg_component); }
+  void set_component(llvm::Value *val) { Instr->setOperand(arg_component, val); }
+  // Reads component as a zero-extended integer narrowed to int8_t.
+  // NOTE(review): the dyn_cast result is dereferenced unchecked, so the
+  // operand must already be an llvm::ConstantInt when this is called.
+  int8_t get_component_val() const { return (int8_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(arg_component))->getZExtValue()); }
+  // Replaces component with an 8-bit integer constant holding val.
+  void set_component_val(int8_t val) { Instr->setOperand(arg_component, llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), 8), llvm::APInt(8, (uint64_t)val))); }
+};
+
+/// This instruction returns ray flags
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_RayFlags; accessors address call operands by the
+/// OperandIdx positions.
+struct DxilInst_RayQuery_RayFlags {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_RayFlags(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_RayFlags);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 2 arguments: argument 0, which
+  // this wrapper does not name, plus rayQueryHandle. Anything that is not a
+  // CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 2 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+};
+
+/// This instruction returns world ray origin
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_WorldRayOrigin; accessors address call operands by the
+/// OperandIdx positions.
+struct DxilInst_RayQuery_WorldRayOrigin {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_WorldRayOrigin(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_WorldRayOrigin);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 3 arguments: argument 0, which
+  // this wrapper does not name, plus the 2 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 3 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_component = 2,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+  llvm::Value *get_component() const { return Instr->getOperand(arg_component); }
+  void set_component(llvm::Value *val) { Instr->setOperand(arg_component, val); }
+  // Reads component as a zero-extended integer narrowed to int8_t.
+  // NOTE(review): the dyn_cast result is dereferenced unchecked, so the
+  // operand must already be an llvm::ConstantInt when this is called.
+  int8_t get_component_val() const { return (int8_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(arg_component))->getZExtValue()); }
+  // Replaces component with an 8-bit integer constant holding val.
+  void set_component_val(int8_t val) { Instr->setOperand(arg_component, llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), 8), llvm::APInt(8, (uint64_t)val))); }
+};
+
+/// This instruction returns world ray direction
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_WorldRayDirection; accessors address call operands by
+/// the OperandIdx positions.
+struct DxilInst_RayQuery_WorldRayDirection {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_WorldRayDirection(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_WorldRayDirection);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 3 arguments: argument 0, which
+  // this wrapper does not name, plus the 2 entries of OperandIdx. Anything
+  // that is not a CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 3 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_component = 2,
+  };
+  // Accessors (each getter/setter pair targets the matching OperandIdx slot)
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+  llvm::Value *get_component() const { return Instr->getOperand(arg_component); }
+  void set_component(llvm::Value *val) { Instr->setOperand(arg_component, val); }
+  // Reads component as a zero-extended integer narrowed to int8_t.
+  // NOTE(review): the dyn_cast result is dereferenced unchecked, so the
+  // operand must already be an llvm::ConstantInt when this is called.
+  int8_t get_component_val() const { return (int8_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(arg_component))->getZExtValue()); }
+  // Replaces component with an 8-bit integer constant holding val.
+  void set_component_val(int8_t val) { Instr->setOperand(arg_component, llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), 8), llvm::APInt(8, (uint64_t)val))); }
+};
+
+/// This instruction returns float representing the parametric starting point for the ray.
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_RayTMin; accessors address call operands by the
+/// OperandIdx positions.
+struct DxilInst_RayQuery_RayTMin {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_RayTMin(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_RayTMin);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 2 arguments: argument 0, which
+  // this wrapper does not name, plus rayQueryHandle. Anything that is not a
+  // CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 2 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+};
+
+/// This instruction returns float representing the parametric point on the ray for the current candidate triangle hit.
+///
+/// Thin wrapper over a raw llvm::Instruction pointer. It converts to true only
+/// when hlsl::OP::IsDxilOpFuncCallInst accepts the instruction for
+/// OpCode::RayQuery_CandidateTriangleRayT; accessors address call operands
+/// by the OperandIdx positions.
+struct DxilInst_RayQuery_CandidateTriangleRayT {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CandidateTriangleRayT(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CandidateTriangleRayT);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  // True when Instr is a call with exactly 2 arguments: argument 0, which
+  // this wrapper does not name, plus rayQueryHandle. Anything that is not a
+  // CallInst yields false rather than a null dereference.
+  bool isArgumentListValid() const {
+    const llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(Instr);
+    return CI != nullptr && 2 == CI->getNumArgOperands();
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(arg_rayQueryHandle); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(arg_rayQueryHandle, val); }
+};
+
+/// This instruction returns float representing the parametric point on the ray for the current committed hit.
+struct DxilInst_RayQuery_CommittedRayT {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommittedRayT(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommittedRayT);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+};
+
+/// This instruction returns candidate hit instance index
+struct DxilInst_RayQuery_CandidateInstanceIndex {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CandidateInstanceIndex(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CandidateInstanceIndex);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+};
+
+/// This instruction returns candidate hit instance ID
+struct DxilInst_RayQuery_CandidateInstanceID {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CandidateInstanceID(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CandidateInstanceID);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+};
+
+/// This instruction returns candidate hit geometry index
+struct DxilInst_RayQuery_CandidateGeometryIndex {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CandidateGeometryIndex(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CandidateGeometryIndex);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+};
+
+/// This instruction returns candidate hit primitive index
+struct DxilInst_RayQuery_CandidatePrimitiveIndex {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CandidatePrimitiveIndex(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CandidatePrimitiveIndex);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+};
+
+/// This instruction returns candidate hit object ray origin
+struct DxilInst_RayQuery_CandidateObjectRayOrigin {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CandidateObjectRayOrigin(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CandidateObjectRayOrigin);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (3 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_component = 2,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_component() const { return Instr->getOperand(2); }
+  void set_component(llvm::Value *val) { Instr->setOperand(2, val); }
+  int8_t get_component_val() const { return (int8_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(2))->getZExtValue()); }
+  void set_component_val(int8_t val) { Instr->setOperand(2, llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), 8), llvm::APInt(8, (uint64_t)val))); }
+};
+
+/// This instruction returns candidate object ray direction
+struct DxilInst_RayQuery_CandidateObjectRayDirection {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CandidateObjectRayDirection(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CandidateObjectRayDirection);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (3 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_component = 2,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_component() const { return Instr->getOperand(2); }
+  void set_component(llvm::Value *val) { Instr->setOperand(2, val); }
+  int8_t get_component_val() const { return (int8_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(2))->getZExtValue()); }
+  void set_component_val(int8_t val) { Instr->setOperand(2, llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), 8), llvm::APInt(8, (uint64_t)val))); }
+};
+
+/// This instruction returns committed hit instance index
+struct DxilInst_RayQuery_CommittedInstanceIndex {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommittedInstanceIndex(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommittedInstanceIndex);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+};
+
+/// This instruction returns committed hit instance ID
+struct DxilInst_RayQuery_CommittedInstanceID {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommittedInstanceID(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommittedInstanceID);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+};
+
+/// This instruction returns committed hit geometry index
+struct DxilInst_RayQuery_CommittedGeometryIndex {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommittedGeometryIndex(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommittedGeometryIndex);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+};
+
+/// This instruction returns committed hit primitive index
+struct DxilInst_RayQuery_CommittedPrimitiveIndex {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommittedPrimitiveIndex(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommittedPrimitiveIndex);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+};
+
+/// This instruction returns committed hit object ray origin
+struct DxilInst_RayQuery_CommittedObjectRayOrigin {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommittedObjectRayOrigin(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommittedObjectRayOrigin);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (3 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_component = 2,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_component() const { return Instr->getOperand(2); }
+  void set_component(llvm::Value *val) { Instr->setOperand(2, val); }
+  int8_t get_component_val() const { return (int8_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(2))->getZExtValue()); }
+  void set_component_val(int8_t val) { Instr->setOperand(2, llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), 8), llvm::APInt(8, (uint64_t)val))); }
+};
+
+/// This instruction returns committed object ray direction
+struct DxilInst_RayQuery_CommittedObjectRayDirection {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_RayQuery_CommittedObjectRayDirection(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::RayQuery_CommittedObjectRayDirection);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (3 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_component = 2,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_component() const { return Instr->getOperand(2); }
+  void set_component(llvm::Value *val) { Instr->setOperand(2, val); }
+  int8_t get_component_val() const { return (int8_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(2))->getZExtValue()); }
+  void set_component_val(int8_t val) { Instr->setOperand(2, llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), 8), llvm::APInt(8, (uint64_t)val))); }
+};
+
+/// This instruction returns the autogenerated index of the current geometry in the bottom-level structure
+struct DxilInst_GeometryIndex {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_GeometryIndex(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::GeometryIndex);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (1 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+};
 // INSTR-HELPER:END
 } // namespace hlsl

+ 36 - 0
include/dxc/DXIL/DxilMetadataHelper.h

@@ -46,6 +46,7 @@ class DxilSampler;
 class DxilTypeSystem;
 class DxilStructAnnotation;
 class DxilFieldAnnotation;
+class DxilTemplateArgAnnotation;
 class DxilFunctionAnnotation;
 class DxilParameterAnnotation;
 class RootSignatureHandle;
@@ -191,6 +192,13 @@ public:
   static const unsigned kDxilFieldAnnotationCompTypeTag           = 7;
   static const unsigned kDxilFieldAnnotationPreciseTag            = 8;
 
+  // StructAnnotation extended property tags (DXIL 1.5+ only, appended)
+  static const unsigned kDxilTemplateArgumentsTag                 = 0;  // Name for name-value list of extended struct properties
+  // TemplateArgument tags
+  static const unsigned kDxilTemplateArgTypeTag                   = 0;  // Type template argument, followed by undef of type
+  static const unsigned kDxilTemplateArgIntegralTag               = 1;  // Integral template argument, followed by i64 value
+  static const unsigned kDxilTemplateArgValue                     = 1;  // Position of template arg value (type or int)
+
   // Control flow hint.
   static const char kDxilControlFlowHintMDName[];
 
@@ -220,6 +228,8 @@ public:
   static const unsigned kDxilRayPayloadSizeTag  = 6;
   static const unsigned kDxilRayAttribSizeTag   = 7;
   static const unsigned kDxilShaderKindTag      = 8;
+  static const unsigned kDxilMSStateTag         = 9;
+  static const unsigned kDxilASStateTag         = 10;
 
   // GSState.
   static const unsigned kDxilGSStateNumFields               = 5;
@@ -244,6 +254,17 @@ public:
   static const unsigned kDxilHSStateTessellatorOutputPrimitive= 5;
   static const unsigned kDxilHSStateMaxTessellationFactor     = 6;
 
+  // MSState.
+  static const unsigned kDxilMSStateNumFields = 4;
+  static const unsigned kDxilMSStateNumThreads = 0;
+  static const unsigned kDxilMSStateMaxVertexCount = 1;
+  static const unsigned kDxilMSStateMaxPrimitiveCount = 2;
+  static const unsigned kDxilMSStateOutputTopology = 3;
+
+  // ASState.
+  static const unsigned kDxilASStateNumFields = 1;
+  static const unsigned kDxilASStateNumThreads = 0;
+
 public:
   /// Use this class to manipulate metadata of DXIL or high-level DX IR specific fields in the record.
   class ExtraPropertyHelper {
@@ -351,6 +372,8 @@ public:
   void LoadDxilParamAnnotation(const llvm::MDOperand &MDO, DxilParameterAnnotation &PA);
   llvm::Metadata *EmitDxilParamAnnotations(const DxilFunctionAnnotation &FA);
   void LoadDxilParamAnnotations(const llvm::MDOperand &MDO, DxilFunctionAnnotation &FA);
+  llvm::Metadata *EmitDxilTemplateArgAnnotation(const DxilTemplateArgAnnotation &annotation);
+  void LoadDxilTemplateArgAnnotation(const llvm::MDOperand &MDO, DxilTemplateArgAnnotation &annotation);
 
   // Function props.
   llvm::MDTuple *EmitDxilFunctionProps(const hlsl::DxilFunctionProps *props,
@@ -406,6 +429,19 @@ private:
                        DXIL::TessellatorPartitioning &TessPartitioning,
                        DXIL::TessellatorOutputPrimitive &TessOutputPrimitive,
                        float &MaxTessFactor);
+
+  llvm::MDTuple *EmitDxilMSState(const unsigned *NumThreads,
+                                 unsigned MaxVertexCount,
+                                 unsigned MaxPrimitiveCount,
+                                 DXIL::MeshOutputTopology OutputTopology);
+  void LoadDxilMSState(const llvm::MDOperand &MDO,
+                       unsigned *NumThreads,
+                       unsigned &MaxVertexCount,
+                       unsigned &MaxPrimitiveCount,
+                       DXIL::MeshOutputTopology &OutputTopology);
+
+  llvm::MDTuple *EmitDxilASState(const unsigned *NumThreads);
+  void LoadDxilASState(const llvm::MDOperand &MDO, unsigned *NumThreads);
 public:
   // Utility functions.
   static bool IsKnownNamedMetaData(const llvm::NamedMDNode &Node);

+ 12 - 2
include/dxc/DXIL/DxilModule.h

@@ -118,8 +118,8 @@ public:
   const DxilSignature &GetInputSignature() const;
   DxilSignature &GetOutputSignature();
   const DxilSignature &GetOutputSignature() const;
-  DxilSignature &GetPatchConstantSignature();
-  const DxilSignature &GetPatchConstantSignature() const;
+  DxilSignature &GetPatchConstOrPrimSignature();
+  const DxilSignature &GetPatchConstOrPrimSignature() const;
   const std::vector<uint8_t> &GetSerializedRootSignature() const;
   std::vector<uint8_t> &GetSerializedRootSignature();
 
@@ -273,6 +273,16 @@ public:
   float GetMaxTessellationFactor() const;
   void SetMaxTessellationFactor(float MaxTessellationFactor);
 
+  // Mesh shader
+  unsigned GetMaxOutputVertices() const;
+  void SetMaxOutputVertices(unsigned NumOVs);
+  unsigned GetMaxOutputPrimitives() const;
+  void SetMaxOutputPrimitives(unsigned NumOPs);
+  DXIL::MeshOutputTopology GetMeshOutputTopology() const;
+  void SetMeshOutputTopology(DXIL::MeshOutputTopology MeshOutputTopology);
+  unsigned GetPayloadByteSize() const;
+  void SetPayloadByteSize(unsigned Size);
+
   // AutoBindingSpace also enables automatic binding for libraries if set.
   // UINT_MAX == unset
   void SetAutoBindingSpace(uint32_t Space);

+ 1 - 0
include/dxc/DXIL/DxilOperations.h

@@ -89,6 +89,7 @@ public:
   static const char *GetAtomicOpName(DXIL::AtomicBinOpCode OpCode);
   static OpCodeClass GetOpCodeClass(OpCode OpCode);
   static const char *GetOpCodeClassName(OpCode OpCode);
+  static llvm::Attribute::AttrKind GetMemAccessAttr(OpCode opCode);
   static bool IsOverloadLegal(OpCode OpCode, llvm::Type *pType);
   static bool CheckOpCodeTable();
   static bool IsDxilOpFuncName(llvm::StringRef name);

+ 10 - 1
include/dxc/DXIL/DxilShaderFlags.h

@@ -108,6 +108,12 @@ namespace hlsl {
     void SetShadingRate(bool flag) { m_bShadingRate = flag; }
     bool GetShadingRate() const { return m_bShadingRate; }
 
+    void SetRaytracingTier1_1(bool flag) { m_bRaytracingTier1_1 = flag; }
+    bool GetRaytracingTier1_1() const { return m_bRaytracingTier1_1; }
+
+    void SetSamplerFeedback(bool flag) { m_bSamplerFeedback = flag; }
+    bool GetSamplerFeedback() const { return m_bSamplerFeedback; }
+
   private:
     unsigned m_bDisableOptimizations :1;   // D3D11_1_SB_GLOBAL_FLAG_SKIP_OPTIMIZATION
     unsigned m_bDisableMathRefactoring :1; //~D3D10_SB_GLOBAL_FLAG_REFACTORING_ALLOWED
@@ -143,7 +149,10 @@ namespace hlsl {
 
     unsigned m_bShadingRate : 1;      // SHADER_FEATURE_SHADINGRATE
 
-    unsigned m_align0 : 7;        // align to 32 bit.
+    unsigned m_bRaytracingTier1_1 : 1; // SHADER_FEATURE_RAYTRACING_TIER_1_1
+    unsigned m_bSamplerFeedback : 1; // SHADER_FEATURE_SAMPLER_FEEDBACK
+
+    unsigned m_align0 : 5;        // align to 32 bit.
     uint32_t m_align1;            // align to 64 bit.
   };
 

+ 3 - 1
include/dxc/DXIL/DxilShaderModel.h

@@ -40,6 +40,8 @@ public:
   bool IsCS() const     { return m_Kind == Kind::Compute; }
   bool IsLib() const    { return m_Kind == Kind::Library; }
   bool IsRay() const    { return m_Kind >= Kind::RayGeneration && m_Kind <= Kind::Callable; }
+  bool IsMS() const     { return m_Kind == Kind::Mesh; }
+  bool IsAS() const     { return m_Kind == Kind::Amplification; }
   bool IsValid() const;
   bool IsValidForDxil() const;
   bool IsValidForModule() const;
@@ -96,7 +98,7 @@ private:
               unsigned m_NumInputRegs, unsigned m_NumOutputRegs,
               bool m_bUAVs, bool m_bTypedUavs, unsigned m_UAVRegsLim);
 
-  static const unsigned kNumShaderModels = 63;
+  static const unsigned kNumShaderModels = 65;
   static const ShaderModel ms_ShaderModels[kNumShaderModels];
 
   static const ShaderModel *GetInvalid();

+ 1 - 1
include/dxc/DXIL/DxilSigPoint.h

@@ -34,7 +34,7 @@ public:
 
   bool IsInput() const { return m_SignatureKind == DXIL::SignatureKind::Input; }
   bool IsOutput() const { return m_SignatureKind == DXIL::SignatureKind::Output; }
-  bool IsPatchConstant() const { return m_SignatureKind == DXIL::SignatureKind::PatchConstant; }
+  bool IsPatchConstOrPrim() const { return m_SignatureKind == DXIL::SignatureKind::PatchConstOrPrim; }
 
   Kind GetKind() const { return m_Kind; }
   const char *GetName() const { return m_pszName; }

+ 78 - 52
include/dxc/DXIL/DxilSigPoint.inl

@@ -19,25 +19,29 @@ namespace hlsl {
 // for compatibility purposes.
 // <py::lines('SIGPOINT-TABLE')>hctdb_instrhelp.get_sigpoint_table()</py>
 // SIGPOINT-TABLE:BEGIN
-//   SigPoint, Related, ShaderKind, PackingKind,    SignatureKind
+//   SigPoint, Related, ShaderKind,    PackingKind,    SignatureKind
 #define DO_SIGPOINTS(ROW) \
-  ROW(VSIn,     Invalid, Vertex,     InputAssembler, Input) \
-  ROW(VSOut,    Invalid, Vertex,     Vertex,         Output) \
-  ROW(PCIn,     HSCPIn,  Hull,       None,           Invalid) \
-  ROW(HSIn,     HSCPIn,  Hull,       None,           Invalid) \
-  ROW(HSCPIn,   Invalid, Hull,       Vertex,         Input) \
-  ROW(HSCPOut,  Invalid, Hull,       Vertex,         Output) \
-  ROW(PCOut,    Invalid, Hull,       PatchConstant,  PatchConstant) \
-  ROW(DSIn,     Invalid, Domain,     PatchConstant,  PatchConstant) \
-  ROW(DSCPIn,   Invalid, Domain,     Vertex,         Input) \
-  ROW(DSOut,    Invalid, Domain,     Vertex,         Output) \
-  ROW(GSVIn,    Invalid, Geometry,   Vertex,         Input) \
-  ROW(GSIn,     GSVIn,   Geometry,   None,           Invalid) \
-  ROW(GSOut,    Invalid, Geometry,   Vertex,         Output) \
-  ROW(PSIn,     Invalid, Pixel,      Vertex,         Input) \
-  ROW(PSOut,    Invalid, Pixel,      Target,         Output) \
-  ROW(CSIn,     Invalid, Compute,    None,           Invalid) \
-  ROW(Invalid,  Invalid, Invalid,    Invalid,        Invalid)
+  ROW(VSIn,     Invalid, Vertex,        InputAssembler, Input) \
+  ROW(VSOut,    Invalid, Vertex,        Vertex,         Output) \
+  ROW(PCIn,     HSCPIn,  Hull,          None,           Invalid) \
+  ROW(HSIn,     HSCPIn,  Hull,          None,           Invalid) \
+  ROW(HSCPIn,   Invalid, Hull,          Vertex,         Input) \
+  ROW(HSCPOut,  Invalid, Hull,          Vertex,         Output) \
+  ROW(PCOut,    Invalid, Hull,          PatchConstant,  PatchConstOrPrim) \
+  ROW(DSIn,     Invalid, Domain,        PatchConstant,  PatchConstOrPrim) \
+  ROW(DSCPIn,   Invalid, Domain,        Vertex,         Input) \
+  ROW(DSOut,    Invalid, Domain,        Vertex,         Output) \
+  ROW(GSVIn,    Invalid, Geometry,      Vertex,         Input) \
+  ROW(GSIn,     GSVIn,   Geometry,      None,           Invalid) \
+  ROW(GSOut,    Invalid, Geometry,      Vertex,         Output) \
+  ROW(PSIn,     Invalid, Pixel,         Vertex,         Input) \
+  ROW(PSOut,    Invalid, Pixel,         Target,         Output) \
+  ROW(CSIn,     Invalid, Compute,       None,           Invalid) \
+  ROW(MSIn,     Invalid, Mesh,          None,           Invalid) \
+  ROW(MSOut,    Invalid, Mesh,          Vertex,         Output) \
+  ROW(MSPOut,   Invalid, Mesh,          Vertex,         PatchConstOrPrim) \
+  ROW(ASIn,     Invalid, Amplification, None,           Invalid) \
+  ROW(Invalid,  Invalid, Invalid,       Invalid,        Invalid)
 // SIGPOINT-TABLE:END
 
 const SigPoint SigPoint::ms_SigPoints[kNumSigPointRecords] = {
@@ -49,38 +53,39 @@ const SigPoint SigPoint::ms_SigPoints[kNumSigPointRecords] = {
 
 // <py::lines('INTERPRETATION-TABLE')>hctdb_instrhelp.get_interpretation_table()</py>
 // INTERPRETATION-TABLE:BEGIN
-//   Semantic,               VSIn,         VSOut,    PCIn,         HSIn,         HSCPIn,   HSCPOut,  PCOut,      DSIn,         DSCPIn,   DSOut,    GSVIn,    GSIn,         GSOut,    PSIn,          PSOut,         CSIn
+//   Semantic,               VSIn,         VSOut,    PCIn,         HSIn,         HSCPIn,   HSCPOut,  PCOut,      DSIn,         DSCPIn,   DSOut,    GSVIn,    GSIn,         GSOut,    PSIn,          PSOut,         CSIn,     MSIn,         MSOut,        MSPOut,  ASIn
 #define DO_INTERPRETATION_TABLE(ROW) \
-  ROW(Arbitrary,              Arb,          Arb,      NA,           NA,           Arb,      Arb,      Arb,        Arb,          Arb,      Arb,      Arb,      NA,           Arb,      Arb,           NA,            NA) \
-  ROW(VertexID,               SV,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA) \
-  ROW(InstanceID,             SV,           Arb,      NA,           NA,           Arb,      Arb,      NA,         NA,           Arb,      Arb,      Arb,      NA,           Arb,      Arb,           NA,            NA) \
-  ROW(Position,               Arb,          SV,       NA,           NA,           SV,       SV,       Arb,        Arb,          SV,       SV,       SV,       NA,           SV,       SV,            NA,            NA) \
-  ROW(RenderTargetArrayIndex, Arb,          SV,       NA,           NA,           SV,       SV,       Arb,        Arb,          SV,       SV,       SV,       NA,           SV,       SV,            NA,            NA) \
-  ROW(ViewPortArrayIndex,     Arb,          SV,       NA,           NA,           SV,       SV,       Arb,        Arb,          SV,       SV,       SV,       NA,           SV,       SV,            NA,            NA) \
-  ROW(ClipDistance,           Arb,          ClipCull, NA,           NA,           ClipCull, ClipCull, Arb,        Arb,          ClipCull, ClipCull, ClipCull, NA,           ClipCull, ClipCull,      NA,            NA) \
-  ROW(CullDistance,           Arb,          ClipCull, NA,           NA,           ClipCull, ClipCull, Arb,        Arb,          ClipCull, ClipCull, ClipCull, NA,           ClipCull, ClipCull,      NA,            NA) \
-  ROW(OutputControlPointID,   NA,           NA,       NA,           NotInSig,     NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA) \
-  ROW(DomainLocation,         NA,           NA,       NA,           NA,           NA,       NA,       NA,         NotInSig,     NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA) \
-  ROW(PrimitiveID,            NA,           NA,       NotInSig,     NotInSig,     NA,       NA,       NA,         NotInSig,     NA,       NA,       NA,       Shadow,       SGV,      SGV,           NA,            NA) \
-  ROW(GSInstanceID,           NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NotInSig,     NA,       NA,            NA,            NA) \
-  ROW(SampleIndex,            NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       Shadow _41,    NA,            NA) \
-  ROW(IsFrontFace,            NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           SGV,      SGV,           NA,            NA) \
-  ROW(Coverage,               NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NotInSig _50,  NotPacked _41, NA) \
-  ROW(InnerCoverage,          NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NotInSig _50,  NA,            NA) \
-  ROW(Target,                 NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            Target,        NA) \
-  ROW(Depth,                  NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked,     NA) \
-  ROW(DepthLessEqual,         NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked _50, NA) \
-  ROW(DepthGreaterEqual,      NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked _50, NA) \
-  ROW(StencilRef,             NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked _50, NA) \
-  ROW(DispatchThreadID,       NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig) \
-  ROW(GroupID,                NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig) \
-  ROW(GroupIndex,             NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig) \
-  ROW(GroupThreadID,          NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig) \
-  ROW(TessFactor,             NA,           NA,       NA,           NA,           NA,       NA,       TessFactor, TessFactor,   NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA) \
-  ROW(InsideTessFactor,       NA,           NA,       NA,           NA,           NA,       NA,       TessFactor, TessFactor,   NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA) \
-  ROW(ViewID,                 NotInSig _61, NA,       NotInSig _61, NotInSig _61, NA,       NA,       NA,         NotInSig _61, NA,       NA,       NA,       NotInSig _61, NA,       NotInSig _61,  NA,            NA) \
-  ROW(Barycentrics,           NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NotPacked _61, NA,            NA) \
-  ROW(ShadingRate,            NA,           SV _64,   NA,           NA,           SV _64,   SV _64,   NA,         NA,           SV _64,   SV _64,   SV _64,   NA,           SV _64,   SV _64,        NA,            NA)
+  ROW(Arbitrary,              Arb,          Arb,      NA,           NA,           Arb,      Arb,      Arb,        Arb,          Arb,      Arb,      Arb,      NA,           Arb,      Arb,           NA,            NA,       NA,           Arb _65,      Arb _65, NA) \
+  ROW(VertexID,               SV,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(InstanceID,             SV,           Arb,      NA,           NA,           Arb,      Arb,      NA,         NA,           Arb,      Arb,      Arb,      NA,           Arb,      Arb,           NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(Position,               Arb,          SV,       NA,           NA,           SV,       SV,       Arb,        Arb,          SV,       SV,       SV,       NA,           SV,       SV,            NA,            NA,       NA,           SV _65,       NA,      NA) \
+  ROW(RenderTargetArrayIndex, Arb,          SV,       NA,           NA,           SV,       SV,       Arb,        Arb,          SV,       SV,       SV,       NA,           SV,       SV,            NA,            NA,       NA,           NA,           SV _65,  NA) \
+  ROW(ViewPortArrayIndex,     Arb,          SV,       NA,           NA,           SV,       SV,       Arb,        Arb,          SV,       SV,       SV,       NA,           SV,       SV,            NA,            NA,       NA,           NA,           SV _65,  NA) \
+  ROW(ClipDistance,           Arb,          ClipCull, NA,           NA,           ClipCull, ClipCull, Arb,        Arb,          ClipCull, ClipCull, ClipCull, NA,           ClipCull, ClipCull,      NA,            NA,       NA,           ClipCull _65, NA,      NA) \
+  ROW(CullDistance,           Arb,          ClipCull, NA,           NA,           ClipCull, ClipCull, Arb,        Arb,          ClipCull, ClipCull, ClipCull, NA,           ClipCull, ClipCull,      NA,            NA,       NA,           ClipCull _65, NA,      NA) \
+  ROW(OutputControlPointID,   NA,           NA,       NA,           NotInSig,     NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(DomainLocation,         NA,           NA,       NA,           NA,           NA,       NA,       NA,         NotInSig,     NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(PrimitiveID,            NA,           NA,       NotInSig,     NotInSig,     NA,       NA,       NA,         NotInSig,     NA,       NA,       NA,       Shadow,       SGV,      SGV,           NA,            NA,       NA,           NA,           SV _65,  NA) \
+  ROW(GSInstanceID,           NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NotInSig,     NA,       NA,            NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(SampleIndex,            NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       Shadow _41,    NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(IsFrontFace,            NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           SGV,      SGV,           NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(Coverage,               NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NotInSig _50,  NotPacked _41, NA,       NA,           NA,           NA,      NA) \
+  ROW(InnerCoverage,          NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NotInSig _50,  NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(Target,                 NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            Target,        NA,       NA,           NA,           NA,      NA) \
+  ROW(Depth,                  NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked,     NA,       NA,           NA,           NA,      NA) \
+  ROW(DepthLessEqual,         NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked _50, NA,       NA,           NA,           NA,      NA) \
+  ROW(DepthGreaterEqual,      NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked _50, NA,       NA,           NA,           NA,      NA) \
+  ROW(StencilRef,             NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked _50, NA,       NA,           NA,           NA,      NA) \
+  ROW(DispatchThreadID,       NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig, NotInSig _65, NA,           NA,      NotInSig _65) \
+  ROW(GroupID,                NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig, NotInSig _65, NA,           NA,      NotInSig _65) \
+  ROW(GroupIndex,             NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig, NotInSig _65, NA,           NA,      NotInSig _65) \
+  ROW(GroupThreadID,          NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig, NotInSig _65, NA,           NA,      NotInSig _65) \
+  ROW(TessFactor,             NA,           NA,       NA,           NA,           NA,       NA,       TessFactor, TessFactor,   NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(InsideTessFactor,       NA,           NA,       NA,           NA,           NA,       NA,       TessFactor, TessFactor,   NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(ViewID,                 NotInSig _61, NA,       NotInSig _61, NotInSig _61, NA,       NA,       NA,         NotInSig _61, NA,       NA,       NA,       NotInSig _61, NA,       NotInSig _61,  NA,            NA,       NotInSig _65, NA,           NA,      NA) \
+  ROW(Barycentrics,           NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NotPacked _61, NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(ShadingRate,            NA,           SV _64,   NA,           NA,           SV _64,   SV _64,   NA,         NA,           SV _64,   SV _64,   SV _64,   NA,           SV _64,   SV _64,        NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(CullPrimitive,          NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA,       NA,           NA,           SV _65,  NA)
 // INTERPRETATION-TABLE:END
 
 const VersionedSemanticInterpretation SigPoint::ms_SemanticInterpretationTable[(unsigned)DXIL::SemanticKind::Invalid][(unsigned)SigPoint::Kind::Invalid] = {
@@ -88,7 +93,8 @@ const VersionedSemanticInterpretation SigPoint::ms_SemanticInterpretationTable[(
 #define _50 ,5,0
 #define _61 ,6,1
 #define _64 ,6,4
-#define DO_ROW(SEM, VSIn, VSOut, PCIn, HSIn, HSCPIn, HSCPOut, PCOut, DSIn, DSCPIn, DSOut, GSVIn, GSIn, GSOut, PSIn, PSOut, CSIn) \
+#define _65 ,6,5
+#define DO_ROW(SEM, VSIn, VSOut, PCIn, HSIn, HSCPIn, HSCPOut, PCOut, DSIn, DSCPIn, DSOut, GSVIn, GSIn, GSOut, PSIn, PSOut, CSIn, MSIn, MSOut, MSPOut, ASIn) \
   { VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::VSIn), \
     VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::VSOut), \
     VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::PCIn), \
@@ -105,6 +111,10 @@ const VersionedSemanticInterpretation SigPoint::ms_SemanticInterpretationTable[(
     VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::PSIn), \
     VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::PSOut), \
     VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::CSIn), \
+    VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::MSIn), \
+    VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::MSOut), \
+    VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::MSPOut), \
+    VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::ASIn), \
   },
   DO_INTERPRETATION_TABLE(DO_ROW)
 #undef DO_ROW
@@ -179,7 +189,7 @@ DXIL::SigPointKind SigPoint::GetKind(DXIL::ShaderKind shaderKind, DXIL::Signatur
     switch (sigKind) {
     case DXIL::SignatureKind::Input: return DXIL::SigPointKind::HSCPIn;
     case DXIL::SignatureKind::Output: return DXIL::SigPointKind::HSCPOut;
-    case DXIL::SignatureKind::PatchConstant: return DXIL::SigPointKind::PCOut;
+    case DXIL::SignatureKind::PatchConstOrPrim: return DXIL::SigPointKind::PCOut;
     default:
       break;
     }
@@ -188,7 +198,7 @@ DXIL::SigPointKind SigPoint::GetKind(DXIL::ShaderKind shaderKind, DXIL::Signatur
     switch (sigKind) {
     case DXIL::SignatureKind::Input: return DXIL::SigPointKind::DSCPIn;
     case DXIL::SignatureKind::Output: return DXIL::SigPointKind::DSOut;
-    case DXIL::SignatureKind::PatchConstant: return DXIL::SigPointKind::DSIn;
+    case DXIL::SignatureKind::PatchConstOrPrim: return DXIL::SigPointKind::DSIn;
     default:
       break;
     }
@@ -216,6 +226,22 @@ DXIL::SigPointKind SigPoint::GetKind(DXIL::ShaderKind shaderKind, DXIL::Signatur
       break;
     }
     break;
+  case DXIL::ShaderKind::Mesh:
+    switch (sigKind) {
+    case DXIL::SignatureKind::Input: return DXIL::SigPointKind::MSIn;
+    case DXIL::SignatureKind::Output: return DXIL::SigPointKind::MSOut;
+    case DXIL::SignatureKind::PatchConstOrPrim: return DXIL::SigPointKind::MSPOut;
+    default:
+      break;
+    }
+    break;
+  case DXIL::ShaderKind::Amplification:
+    switch (sigKind) {
+    case DXIL::SignatureKind::Input: return DXIL::SigPointKind::ASIn;
+    default:
+      break;
+    }
+    break;
   default:
     break;
   }

+ 4 - 2
include/dxc/DXIL/DxilSignature.h

@@ -52,6 +52,8 @@ public:
 
   static bool ShouldBeAllocated(DXIL::SemanticInterpretationKind);
 
+  unsigned GetRowCount() const;
+
 private:
   DXIL::SigPointKind m_sigPointKind;
   std::vector<std::unique_ptr<DxilSignatureElement> > m_Elements;
@@ -62,12 +64,12 @@ struct DxilEntrySignature {
   DxilEntrySignature(DXIL::ShaderKind shaderKind, bool useMinPrecision)
       : InputSignature(shaderKind, DxilSignature::Kind::Input, useMinPrecision),
         OutputSignature(shaderKind, DxilSignature::Kind::Output, useMinPrecision),
-        PatchConstantSignature(shaderKind, DxilSignature::Kind::PatchConstant, useMinPrecision) {
+        PatchConstOrPrimSignature(shaderKind, DxilSignature::Kind::PatchConstOrPrim, useMinPrecision) {
   }
   DxilEntrySignature(const DxilEntrySignature &src);
   DxilSignature InputSignature;
   DxilSignature OutputSignature;
-  DxilSignature PatchConstantSignature;
+  DxilSignature PatchConstOrPrimSignature;
 };
 
 } // namespace hlsl

+ 1 - 1
include/dxc/DXIL/DxilSignatureElement.h

@@ -50,7 +50,7 @@ public:
 
   bool IsInput() const;
   bool IsOutput() const;
-  bool IsPatchConstant() const;
+  bool IsPatchConstOrPrim() const;
   const char *GetName() const;
   unsigned GetRows() const;
   void SetRows(unsigned Rows);

+ 29 - 1
include/dxc/DXIL/DxilTypeSystem.h

@@ -90,6 +90,22 @@ private:
   std::string m_FieldName;
 };
 
+class DxilTemplateArgAnnotation : DxilFieldAnnotation {  // Annotation for one template argument; NOTE(review): a `class` base without an access specifier is private, so DxilFieldAnnotation's members are not reachable through this type from outside -- confirm this is intended.
+public:
+  DxilTemplateArgAnnotation();
+
+  bool IsType() const;  // Type-valued argument accessors (backed by m_Type, presumably -- definitions are not in this header).
+  const llvm::Type *GetType() const;
+  void SetType(const llvm::Type *pType);
+
+  bool IsIntegral() const;  // Integral-valued argument accessors (backed by m_Integral, presumably -- definitions are not in this header).
+  int64_t GetIntegral() const;
+  void SetIntegral(int64_t i64);
+
+private:
+  const llvm::Type *m_Type;  // Type payload; how "unset" is represented is not visible here -- TODO confirm.
+  int64_t m_Integral;  // Integral payload for non-type arguments.
+};
 
 /// Use this class to represent LLVM structure annotation.
 class DxilStructAnnotation {
@@ -105,10 +121,18 @@ public:
   void SetCBufferSize(unsigned size);
   void MarkEmptyStruct();
   bool IsEmptyStruct();
+
+  // For template args, GetNumTemplateArgs() will return 0 if not a template
+  unsigned GetNumTemplateArgs() const;
+  void SetNumTemplateArgs(unsigned count);
+  DxilTemplateArgAnnotation &GetTemplateArgAnnotation(unsigned argIdx);
+  const DxilTemplateArgAnnotation &GetTemplateArgAnnotation(unsigned argIdx) const;
+
 private:
   const llvm::StructType *m_pStructType;
   std::vector<DxilFieldAnnotation> m_FieldAnnotations;
   unsigned m_CBufferSize;  // The size of struct if inside constant buffer.
+  std::vector<DxilTemplateArgAnnotation> m_TemplateAnnotations;
 };
 
 
@@ -123,6 +147,10 @@ enum class DxilParamInputQual {
   OutStream2,
   OutStream3,
   InputPrimitive,
+  OutIndices,
+  OutVertices,
+  OutPrimitives,
+  InPayload,
 };
 
 /// Use this class to represent type annotation for function parameter.
@@ -164,7 +192,7 @@ public:
 
   DxilTypeSystem(llvm::Module *pModule);
 
-  DxilStructAnnotation *AddStructAnnotation(const llvm::StructType *pStructType);
+  DxilStructAnnotation *AddStructAnnotation(const llvm::StructType *pStructType, unsigned numTemplateArgs = 0);
   DxilStructAnnotation *GetStructAnnotation(const llvm::StructType *pStructType);
   const DxilStructAnnotation *GetStructAnnotation(const llvm::StructType *pStructType) const;
   void EraseStructAnnotation(const llvm::StructType *pStructType);

+ 1 - 0
include/dxc/DXIL/DxilUtil.h

@@ -109,6 +109,7 @@ namespace dxilutil {
   bool ContainsHLSLObjectType(llvm::Type *Ty);
   bool IsHLSLResourceType(llvm::Type *Ty);
   bool IsHLSLObjectType(llvm::Type *Ty);
+  bool IsHLSLRayQueryType(llvm::Type *Ty);
   bool IsSplat(llvm::ConstantDataVector *cdv);
 
   llvm::Type* StripArrayTypes(llvm::Type *Ty, llvm::SmallVectorImpl<unsigned> *OuterToInnerLengths = nullptr);

+ 91 - 74
include/dxc/DxilContainer/DxilPipelineStateValidation.h

@@ -49,6 +49,13 @@ struct PSInfo {
   char DepthOutput;
   char SampleFrequency;
 };
+struct MSInfo {  // Per-shader info record declared after PSInfo and added to PSVRuntimeInfo0 as member `MS`; presumably the mesh-shader counterpart of DSInfo/GSInfo/PSInfo.
+  uint32_t GroupSharedBytesUsed;  // presumably total groupshared bytes used by the shader -- TODO confirm against the writer
+  uint32_t GroupSharedBytesDependentOnViewID;  // presumably the subset of groupshared bytes whose contents depend on ViewID -- TODO confirm
+  uint32_t PayloadSizeInBytes;  // presumably size of the payload passed to the mesh shader -- TODO confirm
+  uint16_t MaxOutputVertices;  // 16-bit count; valid range is not visible here
+  uint16_t MaxOutputPrimitives;  // 16-bit count; valid range is not visible here
+};
 
 // Versioning is additive and based on size
 struct PSVRuntimeInfo0
@@ -59,6 +66,7 @@ struct PSVRuntimeInfo0
     DSInfo DS;
     GSInfo GS;
     PSInfo PS;
+    MSInfo MS;
   };
   uint32_t MinimumExpectedWaveLaneCount;  // minimum lane count required, 0 if unused
   uint32_t MaximumExpectedWaveLaneCount;  // maximum lane count required, 0xffffffff if unused
@@ -79,6 +87,8 @@ enum class PSVShaderKind : uint8_t    // DXIL::ShaderKind
   ClosestHit,
   Miss,
   Callable,
+  Mesh,
+  Amplification,
   Invalid,
 };
 
@@ -88,16 +98,19 @@ struct PSVRuntimeInfo1 : public PSVRuntimeInfo0
   uint8_t UsesViewID;
   union {
     uint16_t MaxVertexCount;          // MaxVertexCount for GS only (max 1024)
-    uint8_t SigPatchConstantVectors;  // Output for HS; Input for DS
+    uint8_t SigPatchConstOrPrimVectors;  // Output for HS; Input for DS; Primitive output for MS
   };
 
   // PSVSignatureElement counts
   uint8_t SigInputElements;
   uint8_t SigOutputElements;
-  uint8_t SigPatchConstantElements;
+  uint8_t SigPatchConstOrPrimElements;
 
   // Number of packed vectors per signature
-  uint8_t SigInputVectors;
+  union {
+    uint8_t SigInputVectors;
+    uint8_t MeshOutputTopology;
+  };
   uint8_t SigOutputVectors[4];      // Array for GS Stream Out Index
 };
 
@@ -320,9 +333,9 @@ struct PSVInitInfo
     UsesViewID(0),
     SigInputElements(0),
     SigOutputElements(0),
-    SigPatchConstantElements(0),
+    SigPatchConstOrPrimElements(0),
     SigInputVectors(0),
-    SigPatchConstantVectors(0)
+    SigPatchConstOrPrimVectors(0)
   {}
   uint32_t PSVVersion;
   uint32_t ResourceCount;
@@ -332,9 +345,9 @@ struct PSVInitInfo
   uint8_t UsesViewID;
   uint8_t SigInputElements;
   uint8_t SigOutputElements;
-  uint8_t SigPatchConstantElements;
+  uint8_t SigPatchConstOrPrimElements;
   uint8_t SigInputVectors;
-  uint8_t SigPatchConstantVectors;
+  uint8_t SigPatchConstOrPrimVectors;
   uint8_t SigOutputVectors[4] = {0, 0, 0, 0};
 };
 
@@ -351,9 +364,9 @@ class DxilPipelineStateValidation
   uint32_t m_uPSVSignatureElementSize;
   void* m_pSigInputElements;
   void* m_pSigOutputElements;
-  void* m_pSigPatchConstantElements;
+  void* m_pSigPatchConstOrPrimElements;
   uint32_t* m_pViewIDOutputMask;
-  uint32_t* m_pViewIDPCOutputMask;
+  uint32_t* m_pViewIDPCOrPrimOutputMask;
   uint32_t* m_pInputToOutputTable;
   uint32_t* m_pInputToPCOutputTable;
   uint32_t* m_pPCInputToOutputTable;
@@ -371,9 +384,9 @@ public:
     m_uPSVSignatureElementSize(0),
     m_pSigInputElements(nullptr),
     m_pSigOutputElements(nullptr),
-    m_pSigPatchConstantElements(nullptr),
+    m_pSigPatchConstOrPrimElements(nullptr),
     m_pViewIDOutputMask(nullptr),
-    m_pViewIDPCOutputMask(nullptr),
+    m_pViewIDPCOrPrimOutputMask(nullptr),
     m_pInputToOutputTable(nullptr),
     m_pInputToPCOutputTable(nullptr),
     m_pPCInputToOutputTable(nullptr)
@@ -394,28 +407,28 @@ public:
   //    uint32_t SemanticIndexTableEntries (number of dwords)
   //    If SemanticIndexTableEntries:
   //      { semantic index } * SemanticIndexTableEntries
-  //    If SigInputElements || SigOutputElements || SigPatchConstantElements:
+  //    If SigInputElements || SigOutputElements || SigPatchConstOrPrimElements:
   //      uint32_t PSVSignatureElement_size
   //      { PSVSignatureElementN structure } * SigInputElements
   //      { PSVSignatureElementN structure } * SigOutputElements
-  //      { PSVSignatureElementN structure } * SigPatchConstantElements
+  //      { PSVSignatureElementN structure } * SigPatchConstOrPrimElements
   //    If (UsesViewID):
   //      For (i : each stream index 0-3):
   //        If (SigOutputVectors[i] non-zero):
   //          { uint32_t * PSVComputeMaskDwordsFromVectors(SigOutputVectors[i]) }
   //            - Outputs affected by ViewID as a bitmask
-  //      If (HS and SigPatchConstantVectors non-zero):
-  //        { uint32_t * PSVComputeMaskDwordsFromVectors(SigPatchConstantVectors) }
+  //      If ((HS or MS) and SigPatchConstOrPrimVectors non-zero):
+  //        { uint32_t * PSVComputeMaskDwordsFromVectors(SigPatchConstOrPrimVectors) }
   //          - PCOutputs affected by ViewID as a bitmask
   //    For (i : each stream index 0-3):
   //      If (SigInputVectors and SigOutputVectors[i] non-zero):
   //        { PSVComputeInputOutputTableSize(SigInputVectors, SigOutputVectors[i]) }
   //          - Outputs affected by inputs as a table of bitmasks
-  //    If (HS and SigPatchConstantVectors and SigInputVectors non-zero):
-  //      { PSVComputeInputOutputTableSize(SigInputVectors, SigPatchConstantVectors) }
+  //    If ((HS or MS) and SigPatchConstOrPrimVectors and SigInputVectors non-zero):
+  //      { PSVComputeInputOutputTableSize(SigInputVectors, SigPatchConstOrPrimVectors) }
   //        - Patch constant outputs affected by inputs as a table of bitmasks
-  //    If (DS and SigOutputVectors[0] and SigPatchConstantVectors non-zero):
-  //      { PSVComputeInputOutputTableSize(SigPatchConstantVectors, SigOutputVectors[0]) }
+  //    If (DS and SigOutputVectors[0] and SigPatchConstOrPrimVectors non-zero):
+  //      { PSVComputeInputOutputTableSize(SigPatchConstOrPrimVectors, SigOutputVectors[0]) }
   //        - Outputs affected by patch constant inputs as a table of bitmasks
   // returns true if no errors occurred.
   bool InitFromPSV0(const void* pBits, uint32_t size) {
@@ -466,14 +479,14 @@ public:
       pCurBits += sizeof(uint32_t) * m_SemanticIndexTable.Entries;
 
       // Dxil Signature Elements
-      if (m_pPSVRuntimeInfo1->SigInputElements || m_pPSVRuntimeInfo1->SigOutputElements || m_pPSVRuntimeInfo1->SigPatchConstantElements) {
+      if (m_pPSVRuntimeInfo1->SigInputElements || m_pPSVRuntimeInfo1->SigOutputElements || m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements) {
         minsize += sizeof(uint32_t);
         if (!(size >= minsize)) return false;
         m_uPSVSignatureElementSize = *(uint32_t*)pCurBits;
         if (m_uPSVSignatureElementSize < sizeof(PSVSignatureElement0))
           return false;   // Illegal: Size smaller than first version
         pCurBits += sizeof(uint32_t);
-        minsize += m_uPSVSignatureElementSize * (m_pPSVRuntimeInfo1->SigInputElements + m_pPSVRuntimeInfo1->SigOutputElements + m_pPSVRuntimeInfo1->SigPatchConstantElements);
+        minsize += m_uPSVSignatureElementSize * (m_pPSVRuntimeInfo1->SigInputElements + m_pPSVRuntimeInfo1->SigOutputElements + m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements);
         if (!(size >= minsize)) return false;
       }
       if (m_pPSVRuntimeInfo1->SigInputElements) {
@@ -484,9 +497,9 @@ public:
         m_pSigOutputElements = (PSVSignatureElement0*)pCurBits;
         pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigOutputElements;
       }
-      if (m_pPSVRuntimeInfo1->SigPatchConstantElements) {
-        m_pSigPatchConstantElements = (PSVSignatureElement0*)pCurBits;
-        pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigPatchConstantElements;
+      if (m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements) {
+        m_pSigPatchConstOrPrimElements = (PSVSignatureElement0*)pCurBits;
+        pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements;
       }
 
       // ViewID dependencies
@@ -501,11 +514,11 @@ public:
           if (!IsGS())
             break;
         }
-        if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstantVectors) {
-          minsize += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+        if ((IsHS() || IsMS()) && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors) {
+          minsize += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
           if (!(size >= minsize)) return false;
-          m_pViewIDPCOutputMask = (uint32_t*)pCurBits;
-          pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+          m_pViewIDPCOrPrimOutputMask = (uint32_t*)pCurBits;
+          pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
         }
       }
 
@@ -520,17 +533,17 @@ public:
         if (!IsGS())
           break;
       }
-      if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstantVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
-        minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+      if ((IsHS() || IsMS()) && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
+        minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
         if (!(size >= minsize)) return false;
         m_pInputToPCOutputTable = (uint32_t*)pCurBits;
-        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
       }
-      if (IsDS() && m_pPSVRuntimeInfo1->SigOutputVectors[0] > 0 && m_pPSVRuntimeInfo1->SigPatchConstantVectors > 0) {
-        minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstantVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
+      if (IsDS() && m_pPSVRuntimeInfo1->SigOutputVectors[0] > 0 && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0) {
+        minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
         if (!(size >= minsize)) return false;
         m_pPCInputToOutputTable = (uint32_t*)pCurBits;
-        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstantVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
+        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
       }
     }
     return true;
@@ -567,12 +580,12 @@ public:
     if (initInfo.PSVVersion > 0) {
       size += sizeof(uint32_t) + PSVALIGN4(initInfo.StringTable.Size);
       size += sizeof(uint32_t) + sizeof(uint32_t) * initInfo.SemanticIndexTable.Entries;
-      if (initInfo.SigInputElements || initInfo.SigOutputElements || initInfo.SigPatchConstantElements) {
+      if (initInfo.SigInputElements || initInfo.SigOutputElements || initInfo.SigPatchConstOrPrimElements) {
         size += sizeof(uint32_t);   // PSVSignatureElement_size
       }
       size += m_uPSVSignatureElementSize * initInfo.SigInputElements;
       size += m_uPSVSignatureElementSize * initInfo.SigOutputElements;
-      size += m_uPSVSignatureElementSize * initInfo.SigPatchConstantElements;
+      size += m_uPSVSignatureElementSize * initInfo.SigPatchConstOrPrimElements;
 
       if (initInfo.UsesViewID) {
         for (unsigned i = 0; i < 4; i++) {
@@ -580,21 +593,23 @@ public:
           if (initInfo.ShaderStage != PSVShaderKind::Geometry)
             break;
         }
-        if (initInfo.ShaderStage == PSVShaderKind::Hull)
-          size += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(initInfo.SigPatchConstantVectors);
+        if (initInfo.ShaderStage == PSVShaderKind::Hull || initInfo.ShaderStage == PSVShaderKind::Mesh)
+          size += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(initInfo.SigPatchConstOrPrimVectors);
       }
-      for (unsigned i = 0; i < 4; i++) {
-        if (initInfo.SigOutputVectors[i] > 0 && initInfo.SigInputVectors > 0) {
-          size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigOutputVectors[i]);
-          if (initInfo.ShaderStage != PSVShaderKind::Geometry)
-            break;
+      if (initInfo.ShaderStage != PSVShaderKind::Mesh && initInfo.ShaderStage != PSVShaderKind::Amplification) {
+        for (unsigned i = 0; i < 4; i++) {
+          if (initInfo.SigOutputVectors[i] > 0 && initInfo.SigInputVectors > 0) {
+            size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigOutputVectors[i]);
+            if (initInfo.ShaderStage != PSVShaderKind::Geometry)
+              break;
+          }
+        }
+        if (initInfo.ShaderStage == PSVShaderKind::Hull && initInfo.SigPatchConstOrPrimVectors > 0 && initInfo.SigInputVectors > 0) {
+          size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigPatchConstOrPrimVectors);
+        }
+        if (initInfo.ShaderStage == PSVShaderKind::Domain && initInfo.SigOutputVectors[0] > 0 && initInfo.SigPatchConstOrPrimVectors > 0) {
+          size += PSVComputeInputOutputTableSize(initInfo.SigPatchConstOrPrimVectors, initInfo.SigOutputVectors[0]);
         }
-      }
-      if (initInfo.ShaderStage == PSVShaderKind::Hull && initInfo.SigPatchConstantVectors > 0 && initInfo.SigInputVectors > 0) {
-        size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigPatchConstantVectors);
-      }
-      if (initInfo.ShaderStage == PSVShaderKind::Domain && initInfo.SigOutputVectors[0] > 0 && initInfo.SigPatchConstantVectors > 0) {
-        size += PSVComputeInputOutputTableSize(initInfo.SigPatchConstantVectors, initInfo.SigOutputVectors[0]);
       }
     }
 
@@ -634,11 +649,11 @@ public:
       m_pPSVRuntimeInfo1->UsesViewID = initInfo.UsesViewID;
       m_pPSVRuntimeInfo1->SigInputElements = initInfo.SigInputElements;
       m_pPSVRuntimeInfo1->SigOutputElements = initInfo.SigOutputElements;
-      m_pPSVRuntimeInfo1->SigPatchConstantElements = initInfo.SigPatchConstantElements;
+      m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements = initInfo.SigPatchConstOrPrimElements;
       m_pPSVRuntimeInfo1->SigInputVectors = initInfo.SigInputVectors;
       memcpy(m_pPSVRuntimeInfo1->SigOutputVectors, initInfo.SigOutputVectors, 4);
-      if (IsHS() || IsDS()) {
-        m_pPSVRuntimeInfo1->SigPatchConstantVectors = initInfo.SigPatchConstantVectors;
+      if (IsHS() || IsDS() || IsMS()) {
+        m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors = initInfo.SigPatchConstOrPrimVectors;
       }
 
       // Note: if original size was unaligned, padding has already been zero initialized
@@ -657,7 +672,7 @@ public:
       pCurBits += sizeof(uint32_t) * m_SemanticIndexTable.Entries;
 
       // Dxil Signature Elements
-      if (m_pPSVRuntimeInfo1->SigInputElements || m_pPSVRuntimeInfo1->SigOutputElements || m_pPSVRuntimeInfo1->SigPatchConstantElements) {
+      if (m_pPSVRuntimeInfo1->SigInputElements || m_pPSVRuntimeInfo1->SigOutputElements || m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements) {
         *(uint32_t*)pCurBits = m_uPSVSignatureElementSize;
         pCurBits += sizeof(uint32_t);
       }
@@ -669,9 +684,9 @@ public:
         m_pSigOutputElements = (PSVSignatureElement0*)pCurBits;
         pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigOutputElements;
       }
-      if (m_pPSVRuntimeInfo1->SigPatchConstantElements) {
-        m_pSigPatchConstantElements = (PSVSignatureElement0*)pCurBits;
-        pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigPatchConstantElements;
+      if (m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements) {
+        m_pSigPatchConstOrPrimElements = (PSVSignatureElement0*)pCurBits;
+        pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements;
       }
 
       // ViewID dependencies
@@ -684,9 +699,9 @@ public:
           if (!IsGS())
             break;
         }
-        if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstantVectors) {
-          m_pViewIDPCOutputMask = (uint32_t*)pCurBits;
-          pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+        if ((IsHS() || IsMS()) && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors) {
+          m_pViewIDPCOrPrimOutputMask = (uint32_t*)pCurBits;
+          pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
         }
       }
 
@@ -699,13 +714,13 @@ public:
         if (!IsGS())
           break;
       }
-      if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstantVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
+      if ((IsHS() || IsMS()) && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
         m_pInputToPCOutputTable = (uint32_t*)pCurBits;
-        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
       }
-      if (IsDS() && m_pPSVRuntimeInfo1->SigOutputVectors[0] > 0 && m_pPSVRuntimeInfo1->SigPatchConstantVectors > 0) {
+      if (IsDS() && m_pPSVRuntimeInfo1->SigOutputVectors[0] > 0 && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0) {
         m_pPCInputToOutputTable = (uint32_t*)pCurBits;
-        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstantVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
+        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
       }
     }
 
@@ -747,9 +762,9 @@ public:
       return m_pPSVRuntimeInfo1->SigOutputElements;
     return 0;
   }
-  uint32_t GetSigPatchConstantElements() const {
+  uint32_t GetSigPatchConstOrPrimElements() const {
     if (m_pPSVRuntimeInfo1)
-      return m_pPSVRuntimeInfo1->SigPatchConstantElements;
+      return m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements;
     return 0;
   }
   PSVSignatureElement0* GetInputElement0(uint32_t index) const {
@@ -770,11 +785,11 @@ public:
     }
     return nullptr;
   }
-  PSVSignatureElement0* GetPatchConstantElement0(uint32_t index) const {
-    if (m_pPSVRuntimeInfo1 && m_pSigPatchConstantElements &&
-        index < m_pPSVRuntimeInfo1->SigPatchConstantElements &&
+  PSVSignatureElement0* GetPatchConstOrPrimElement0(uint32_t index) const {
+    if (m_pPSVRuntimeInfo1 && m_pSigPatchConstOrPrimElements &&
+        index < m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements &&
         sizeof(PSVSignatureElement0) <= m_uPSVSignatureElementSize) {
-      return (PSVSignatureElement0*)((uint8_t*)m_pSigPatchConstantElements +
+      return (PSVSignatureElement0*)((uint8_t*)m_pSigPatchConstOrPrimElements +
         (index * m_uPSVSignatureElementSize));
     }
     return nullptr;
@@ -795,6 +810,8 @@ public:
   bool IsGS() const { return GetShaderKind() == PSVShaderKind::Geometry; }
   bool IsPS() const { return GetShaderKind() == PSVShaderKind::Pixel; }
   bool IsCS() const { return GetShaderKind() == PSVShaderKind::Compute; }
+  bool IsMS() const { return GetShaderKind() == PSVShaderKind::Mesh; }
+  bool IsAS() const { return GetShaderKind() == PSVShaderKind::Amplification; }
 
   // ViewID dependencies
   PSVComponentMask GetViewIDOutputMask(unsigned streamIndex = 0) const {
@@ -803,9 +820,9 @@ public:
     return PSVComponentMask(m_pViewIDOutputMask, m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex]);
   }
   PSVComponentMask GetViewIDPCOutputMask() const {
-    if (!IsHS() || !m_pViewIDPCOutputMask || !m_pPSVRuntimeInfo1 || !m_pPSVRuntimeInfo1->SigPatchConstantVectors)
+    if ((!IsHS() && !IsMS()) || !m_pViewIDPCOrPrimOutputMask || !m_pPSVRuntimeInfo1 || !m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors)
       return PSVComponentMask();
-    return PSVComponentMask(m_pViewIDPCOutputMask, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+    return PSVComponentMask(m_pViewIDPCOrPrimOutputMask, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
   }
 
   // Input to Output dependencies
@@ -816,14 +833,14 @@ public:
     return PSVDependencyTable();
   }
   PSVDependencyTable GetInputToPCOutputTable() const {
-    if (IsHS() && m_pInputToPCOutputTable && m_pPSVRuntimeInfo1) {
-      return PSVDependencyTable(m_pInputToPCOutputTable, m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+    if ((IsHS() || IsMS()) && m_pInputToPCOutputTable && m_pPSVRuntimeInfo1) {
+      return PSVDependencyTable(m_pInputToPCOutputTable, m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
     }
     return PSVDependencyTable();
   }
   PSVDependencyTable GetPCInputToOutputTable() const {
     if (IsDS() && m_pPCInputToOutputTable && m_pPSVRuntimeInfo1) {
-      return PSVDependencyTable(m_pPCInputToOutputTable, m_pPSVRuntimeInfo1->SigPatchConstantVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
+      return PSVDependencyTable(m_pPCInputToOutputTable, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
     }
     return PSVDependencyTable();
   }

+ 3 - 3
include/dxc/HLSL/ComputeViewIdState.h

@@ -55,15 +55,15 @@ struct DxilViewIdStateData {
 
   unsigned m_NumInputSigScalars  = 0;
   unsigned m_NumOutputSigScalars[kNumStreams] = {0,0,0,0};
-  unsigned m_NumPCSigScalars     = 0;
+  unsigned m_NumPCOrPrimSigScalars     = 0;
 
   // Set of scalar outputs dependent on ViewID.
   OutputsDependentOnViewIdType m_OutputsDependentOnViewId[kNumStreams];
-  OutputsDependentOnViewIdType m_PCOutputsDependentOnViewId;
+  OutputsDependentOnViewIdType m_PCOrPrimOutputsDependentOnViewId;
 
   // Set of scalar inputs contributing to computation of scalar outputs.
   InputsContributingToOutputType m_InputsContributingToOutputs[kNumStreams];
-  InputsContributingToOutputType m_InputsContributingToPCOutputs; // HS PC only.
+  InputsContributingToOutputType m_InputsContributingToPCOrPrimOutputs; // HS PC and MS Prim only.
   InputsContributingToOutputType m_PCInputsContributingToOutputs; // DS only.
 
   bool m_bUsesViewId = false;

+ 17 - 0
include/dxc/HLSL/DxilValidation.h

@@ -85,6 +85,9 @@ enum class ValidationRule : unsigned {
   InstrMinPrecisonBitCast, // Bitcast on minprecison types is not allowed
   InstrMipLevelForGetDimension, // Use mip level on buffer when GetDimensions
   InstrMipOnUAVLoad, // uav load don't support mipLevel/sampleIndex
+  InstrMissingSetMeshOutputCounts, // Missing SetMeshOutputCounts call.
+  InstrMultipleGetMeshPayload, // GetMeshPayload cannot be called multiple times.
+  InstrMultipleSetMeshOutputCounts, // SetMeshOutputCounts cannot be called multiple times.
   InstrNoGenericPtrAddrSpaceCast, // Address space cast between pointer types must have one part to be generic address space
   InstrNoIDivByZero, // No signed integer division by zero
   InstrNoIndefiniteAcos, // No indefinite arccosine
@@ -93,6 +96,9 @@ enum class ValidationRule : unsigned {
   InstrNoIndefiniteLog, // No indefinite logarithm
   InstrNoReadingUninitialized, // Instructions should not read uninitialized value
   InstrNoUDivByZero, // No unsigned integer division by zero
+  InstrNonDominatingDispatchMesh, // Non-Dominating DispatchMesh call.
+  InstrNonDominatingSetMeshOutputCounts, // Non-Dominating SetMeshOutputCounts call.
+  InstrNotOnceDispatchMesh, // DispatchMesh must be called exactly once in an Amplification shader.
   InstrOffsetOnUAVLoad, // uav load don't support offset
   InstrOload, // DXIL intrinsic overload must be valid
   InstrOnlyOneAllocConsume, // RWStructuredBuffers may increment or decrement their counters, but not both.
@@ -190,6 +196,7 @@ enum class ValidationRule : unsigned {
 
   // Shader model
   Sm64bitRawBufferLoadStore, // i64/f64 rawBufferLoad/Store overloads are allowed after SM 6.3
+  SmAmplificationShaderPayloadSize, // For shader '%0', payload size is greater than %1
   SmAppendAndConsumeOnSameUAV, // BufferUpdateCounter inc and dec on a given UAV (%d) cannot both be in the same shader for shader model less than 5.1.
   SmCBufferArrayOffsetAlignment, // CBuffer array offset must be aligned to 16-bytes
   SmCBufferElementOverflow, // CBuffer elements must not overflow
@@ -197,6 +204,7 @@ enum class ValidationRule : unsigned {
   SmCBufferTemplateTypeMustBeStruct, // D3D12 constant/texture buffer template element can only be a struct
   SmCSNoSignatures, // Compute shaders must not have shader signatures.
   SmCompletePosition, // Not all elements of SV_Position were written
+  SmConstantInterpMode, // Interpolation mode must be constant for MS primitive output.
   SmCounterOnlyOnStructBuf, // BufferUpdateCounter valid only on structured buffers
   SmDSInputControlPointCountRange, // DS input control point count must be [0..%0].  %1 specified
   SmDomainLocationIdxOOB, // DomainLocation component index out of bounds for the domain.
@@ -213,8 +221,17 @@ enum class ValidationRule : unsigned {
   SmInvalidResourceKind, // Invalid resources kind
   SmInvalidTextureKindOnUAV, // Texture2DMS[Array] or TextureCube[Array] resources are not supported with UAVs
   SmIsoLineOutputPrimitiveMismatch, // Hull Shader declared with IsoLine Domain must specify output primitive point or line. Triangle_cw or triangle_ccw output are not compatible with the IsoLine Domain.
+  SmMaxMSSMSize, // Total Thread Group Shared Memory storage is %0, exceeded %1
   SmMaxTGSMSize, // Total Thread Group Shared Memory storage is %0, exceeded %1
   SmMaxTheadGroup, // Declared Thread Group Count %0 (X*Y*Z) is beyond the valid maximum of %1
+  SmMeshPSigRowCount, // For shader '%0', primitive output signatures are taking up more than %1 rows
+  SmMeshShaderInOutSize, // For shader '%0', input plus output size is greater than %1
+  SmMeshShaderMaxPrimitiveCount, // MS max primitive output count must be [0..%0].  %1 specified
+  SmMeshShaderMaxVertexCount, // MS max vertex output count must be [0..%0].  %1 specified
+  SmMeshShaderOutputSize, // For shader '%0', vertex plus primitive output size is greater than %1
+  SmMeshShaderPayloadSize, // For shader '%0', payload size is greater than %1
+  SmMeshTotalSigRowCount, // For shader '%0', vertex and primitive output signatures are taking up more than %1 rows
+  SmMeshVSigRowCount, // For shader '%0', vertex output signatures are taking up more than %1 rows
   SmMultiStreamMustBePoint, // When multiple GS output streams are used they must be pointlists
   SmName, // Target shader model name must be known
   SmNoInterpMode, // Interpolation mode must be undefined for VS input/PS output/patch constant.

+ 17 - 0
include/dxc/HLSL/HLOperations.h

@@ -296,6 +296,14 @@ const unsigned kGatherCmpStatusArgIndex = 6;
 const unsigned kGatherCmpSampleOffsetArgIndex = 6;
 const unsigned kGatherCmpStatusWithSampleOffsetArgIndex = 9;
 
+// WriteSamplerFeedback.
+const unsigned kWriteSamplerFeedbackSampledArgIndex = 2;
+const unsigned kWriteSamplerFeedbackSamplerArgIndex = 3;
+const unsigned kWriteSamplerFeedbackCoordArgIndex = 4;
+const unsigned kWriteSamplerFeedbackBiasOrLodArgIndex = 5;
+const unsigned kWriteSamplerFeedbackDdxArgIndex = 5;
+const unsigned kWriteSamplerFeedbackDdyArgIndex = 6;
+
 // StreamAppend.
 const unsigned kStreamAppendStreamOpIndex = 1;
 const unsigned kStreamAppendDataOpIndex = 2;
@@ -333,9 +341,18 @@ const unsigned kCreateHandleIndexOpIdx = 2; // Only for array of cbuffer.
 const unsigned kTraceRayRayDescOpIdx = 7;
 const unsigned kTraceRayPayLoadOpIdx = 8;
 
+// TraceRayInline.
+const unsigned kTraceRayInlineRayDescOpIdx = 5;
+
 // ReportIntersection.
 const unsigned kReportIntersectionAttributeOpIdx = 3;
 
+// DispatchMesh
+const unsigned kDispatchMeshOpThreadX = 1;
+const unsigned kDispatchMeshOpThreadY = 2;
+const unsigned kDispatchMeshOpThreadZ = 3;
+const unsigned kDispatchMeshOpPayload = 4;
+
 } // namespace HLOperandIndex
 
 llvm::Function *GetOrCreateHLFunction(llvm::Module &M,

+ 4 - 4
include/dxc/HLSL/ViewIDPipelineValidation.inl

@@ -306,9 +306,9 @@ public:
         [&](unsigned i) -> PSVSignatureElement {
         return PSV.GetSignatureElement(PSV.GetOutputElement0(i));
       });
-      CopyElements(pcSig, DXIL::SigPointKind::PCOut, PSV.GetSigPatchConstantElements(), 0,
+      CopyElements(pcSig, DXIL::SigPointKind::PCOut, PSV.GetSigPatchConstOrPrimElements(), 0,
         [&](unsigned i) -> PSVSignatureElement {
-        return PSV.GetSignatureElement(PSV.GetPatchConstantElement0(i));
+        return PSV.GetSignatureElement(PSV.GetPatchConstOrPrimElement0(i));
       });
 
       // Propagate prior mask through input-output dependencies
@@ -345,9 +345,9 @@ public:
                     [&](unsigned i) -> PSVSignatureElement {
                       return PSV.GetSignatureElement(PSV.GetInputElement0(i));
                     });
-      CopyElements( pcSig, DXIL::SigPointKind::DSIn, PSV.GetSigPatchConstantElements(), 0,
+      CopyElements( pcSig, DXIL::SigPointKind::DSIn, PSV.GetSigPatchConstOrPrimElements(), 0,
                     [&](unsigned i) -> PSVSignatureElement {
-                      return PSV.GetSignatureElement(PSV.GetPatchConstantElement0(i));
+                      return PSV.GetSignatureElement(PSV.GetPatchConstOrPrimElement0(i));
                     });
 
       // Merge prior and input signatures, update prior mask size if necessary

+ 45 - 0
include/dxc/HlslIntrinsicOp.h

@@ -30,11 +30,13 @@ import hctdb_instrhelp
   IOP_D3DCOLORtoUBYTE4,
   IOP_DeviceMemoryBarrier,
   IOP_DeviceMemoryBarrierWithGroupSync,
+  IOP_DispatchMesh,
   IOP_DispatchRaysDimensions,
   IOP_DispatchRaysIndex,
   IOP_EvaluateAttributeAtSample,
   IOP_EvaluateAttributeCentroid,
   IOP_EvaluateAttributeSnapped,
+  IOP_GeometryIndex,
   IOP_GetAttributeAtVertex,
   IOP_GetRenderTargetSampleCount,
   IOP_GetRenderTargetSamplePosition,
@@ -78,6 +80,7 @@ import hctdb_instrhelp
   IOP_RayTCurrent,
   IOP_RayTMin,
   IOP_ReportHit,
+  IOP_SetMeshOutputCounts,
   IOP_TraceRay,
   IOP_WaveActiveAllEqual,
   IOP_WaveActiveAllTrue,
@@ -260,6 +263,48 @@ import hctdb_instrhelp
   MOP_DecrementCounter,
   MOP_IncrementCounter,
   MOP_Consume,
+  MOP_WriteSamplerFeedback,
+  MOP_WriteSamplerFeedbackBias,
+  MOP_WriteSamplerFeedbackGrad,
+  MOP_WriteSamplerFeedbackLevel,
+  MOP_Abort,
+  MOP_CandidateGeometryIndex,
+  MOP_CandidateInstanceID,
+  MOP_CandidateInstanceIndex,
+  MOP_CandidateObjectRayDirection,
+  MOP_CandidateObjectRayOrigin,
+  MOP_CandidateObjectToWorld3x4,
+  MOP_CandidateObjectToWorld4x3,
+  MOP_CandidatePrimitiveIndex,
+  MOP_CandidateProceduralPrimitiveNonOpaque,
+  MOP_CandidateTriangleBarycentrics,
+  MOP_CandidateTriangleFrontFace,
+  MOP_CandidateTriangleRayT,
+  MOP_CandidateType,
+  MOP_CandidateWorldToObject3x4,
+  MOP_CandidateWorldToObject4x3,
+  MOP_CommitNonOpaqueTriangleHit,
+  MOP_CommitProceduralPrimitiveHit,
+  MOP_CommittedGeometryIndex,
+  MOP_CommittedInstanceID,
+  MOP_CommittedInstanceIndex,
+  MOP_CommittedObjectRayDirection,
+  MOP_CommittedObjectRayOrigin,
+  MOP_CommittedObjectToWorld3x4,
+  MOP_CommittedObjectToWorld4x3,
+  MOP_CommittedPrimitiveIndex,
+  MOP_CommittedRayT,
+  MOP_CommittedStatus,
+  MOP_CommittedTriangleBarycentrics,
+  MOP_CommittedTriangleFrontFace,
+  MOP_CommittedWorldToObject3x4,
+  MOP_CommittedWorldToObject4x3,
+  MOP_Proceed,
+  MOP_RayFlags,
+  MOP_RayTMin,
+  MOP_TraceRayInline,
+  MOP_WorldRayDirection,
+  MOP_WorldRayOrigin,
 #ifdef ENABLE_SPIRV_CODEGEN
   MOP_SubpassLoad,
 #endif // ENABLE_SPIRV_CODEGEN

+ 1 - 1
include/dxc/Support/HLSLOptions.td

@@ -291,7 +291,7 @@ def Oconfig : CommaJoined<["-"], "Oconfig=">, Group<spirv_Group>, Flags<[CoreOpt
 // fxc-based flags that don't match those previously defined.
 
 def target_profile : JoinedOrSeparate<["-", "/"], "T">, Flags<[CoreOption]>, Group<hlslcomp_Group>, MetaVarName<"<profile>">,
-  HelpText<"Set target profile. \n\t<profile>: ps_6_0, ps_6_1, ps_6_2, ps_6_3, ps_6_4, ps_6_5, \n\t\t vs_6_0, vs_6_1, vs_6_2, vs_6_3, vs_6_4, vs_6_5, \n\t\t cs_6_0, cs_6_1, cs_6_2, cs_6_3, cs_6_4, cs_6_5, \n\t\t gs_6_0, gs_6_1, gs_6_2, gs_6_3, gs_6_4, gs_6_5, \n\t\t ds_6_0, ds_6_1, ds_6_2, ds_6_3, ds_6_4, ds_6_5, \n\t\t hs_6_0, hs_6_1, hs_6_2, hs_6_3, hs_6_4, hs_6_5, \n\t\t lib_6_3, lib_6_4, lib_6_5">;
+  HelpText<"Set target profile. \n\t<profile>: ps_6_0, ps_6_1, ps_6_2, ps_6_3, ps_6_4, ps_6_5, \n\t\t vs_6_0, vs_6_1, vs_6_2, vs_6_3, vs_6_4, vs_6_5, \n\t\t cs_6_0, cs_6_1, cs_6_2, cs_6_3, cs_6_4, cs_6_5, \n\t\t gs_6_0, gs_6_1, gs_6_2, gs_6_3, gs_6_4, gs_6_5, \n\t\t ds_6_0, ds_6_1, ds_6_2, ds_6_3, ds_6_4, ds_6_5, \n\t\t hs_6_0, hs_6_1, hs_6_2, hs_6_3, hs_6_4, hs_6_5, \n\t\t lib_6_3, lib_6_4, lib_6_5, ms_6_5, as_6_5">;
 def entrypoint :  JoinedOrSeparate<["-", "/"], "E">, Flags<[CoreOption]>, Group<hlslcomp_Group>,
   HelpText<"Entry point name">;
 // /I <include> - already defined above

+ 1 - 0
include/dxc/Support/SPIRVOptions.h

@@ -56,6 +56,7 @@ struct SpirvCodeGenOptions {
   SpirvLayoutRule cBufferLayoutRule;
   SpirvLayoutRule sBufferLayoutRule;
   SpirvLayoutRule tBufferLayoutRule;
+  SpirvLayoutRule ampPayloadLayoutRule;
   llvm::StringRef stageIoOrder;
   llvm::StringRef targetEnv;
   llvm::SmallVector<int32_t, 4> bShift;

+ 4 - 1
include/dxc/dxcapi.internal.h

@@ -87,7 +87,10 @@ enum LEGAL_INTRINSIC_COMPTYPES {
   LICOMPTYPE_ACCELERATION_STRUCT = 31,
   LICOMPTYPE_USER_DEFINED_TYPE = 32,
 
-  LICOMPTYPE_COUNT = 33
+  LICOMPTYPE_TEXTURE2D = 33,
+  LICOMPTYPE_TEXTURE2DARRAY = 34,
+
+  LICOMPTYPE_COUNT = 35
 };
 
 static const BYTE IA_SPECIAL_BASE = 0xf0;

+ 201 - 7
lib/DXIL/DxilMetadataHelper.cpp

@@ -318,13 +318,13 @@ MDTuple *DxilMDHelper::EmitDxilSignatures(const DxilEntrySignature &EntrySig) {
 
   const DxilSignature &InputSig = EntrySig.InputSignature;
   const DxilSignature &OutputSig = EntrySig.OutputSignature;
-  const DxilSignature &PCSig = EntrySig.PatchConstantSignature;
+  const DxilSignature &PCPSig = EntrySig.PatchConstOrPrimSignature;
 
-  if (!InputSig.GetElements().empty() || !OutputSig.GetElements().empty() || !PCSig.GetElements().empty()) {
+  if (!InputSig.GetElements().empty() || !OutputSig.GetElements().empty() || !PCPSig.GetElements().empty()) {
     Metadata *MDVals[kDxilNumSignatureFields];
     MDVals[kDxilInputSignature]         = EmitSignatureMetadata(InputSig);
     MDVals[kDxilOutputSignature]        = EmitSignatureMetadata(OutputSig);
-    MDVals[kDxilPatchConstantSignature] = EmitSignatureMetadata(PCSig);
+    MDVals[kDxilPatchConstantSignature] = EmitSignatureMetadata(PCPSig);
 
     pSignatureTupleMD = MDNode::get(m_Ctx, MDVals);
   }
@@ -354,14 +354,14 @@ void DxilMDHelper::LoadDxilSignatures(const MDOperand &MDO, DxilEntrySignature &
     return;
   DxilSignature &InputSig = EntrySig.InputSignature;
   DxilSignature &OutputSig = EntrySig.OutputSignature;
-  DxilSignature &PCSig = EntrySig.PatchConstantSignature;
+  DxilSignature &PCPSig = EntrySig.PatchConstOrPrimSignature;
   const MDTuple *pTupleMD = dyn_cast<MDTuple>(MDO.get());
   IFTBOOL(pTupleMD != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
   IFTBOOL(pTupleMD->getNumOperands() == kDxilNumSignatureFields, DXC_E_INCORRECT_DXIL_METADATA);
 
   LoadSignatureMetadata(pTupleMD->getOperand(kDxilInputSignature),         InputSig);
   LoadSignatureMetadata(pTupleMD->getOperand(kDxilOutputSignature),        OutputSig);
-  LoadSignatureMetadata(pTupleMD->getOperand(kDxilPatchConstantSignature), PCSig);
+  LoadSignatureMetadata(pTupleMD->getOperand(kDxilPatchConstantSignature), PCPSig);
 }
 
 MDTuple *DxilMDHelper::EmitSignatureMetadata(const DxilSignature &Sig) {
@@ -774,13 +774,63 @@ void DxilMDHelper::LoadDxilTypeSystem(DxilTypeSystem &TypeSystem) {
   }
 }
 
+// Builds the metadata node for a single template argument annotation.
+// A type argument is emitted as { type tag, undef value of that type };
+// an integral argument is emitted as { integral tag, 64-bit constant }.
+// An annotation that is neither produces a node with no operands.
+Metadata *DxilMDHelper::EmitDxilTemplateArgAnnotation(const DxilTemplateArgAnnotation &annotation) {
+  SmallVector<Metadata *, 2> Fields;
+
+  if (annotation.IsType()) {
+    Fields.push_back(Uint32ToConstMD(DxilMDHelper::kDxilTemplateArgTypeTag));
+    // Metadata cannot reference a type directly, so the type rides on an
+    // undef constant of that type.
+    Type *pArgTy = const_cast<Type *>(annotation.GetType());
+    Fields.push_back(ValueAsMetadata::get(UndefValue::get(pArgTy)));
+    return MDNode::get(m_Ctx, Fields);
+  }
+
+  if (annotation.IsIntegral()) {
+    Fields.push_back(Uint32ToConstMD(DxilMDHelper::kDxilTemplateArgIntegralTag));
+    const uint64_t RawBits = (uint64_t)annotation.GetIntegral();
+    Fields.push_back(Uint64ToConstMD(RawBits));
+  }
+
+  return MDNode::get(m_Ctx, Fields);
+}
+// Reads one template argument annotation from its metadata tuple.
+// The first operand is a tag selecting the argument kind; for the type and
+// integral tags exactly two operands are required and the second carries the
+// payload. Malformed metadata is reported through IFTBOOL with
+// DXC_E_INCORRECT_DXIL_METADATA. Unrecognized tags leave `annotation`
+// untouched.
+void DxilMDHelper::LoadDxilTemplateArgAnnotation(const llvm::MDOperand &MDO, DxilTemplateArgAnnotation &annotation) {
+  IFTBOOL(MDO.get() != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  const MDTuple *pTupleMD = dyn_cast<MDTuple>(MDO.get());
+  IFTBOOL(pTupleMD != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  IFTBOOL(pTupleMD->getNumOperands() >= 1, DXC_E_INCORRECT_DXIL_METADATA);
+  unsigned Tag = ConstMDToUint32(pTupleMD->getOperand(0));
+  switch (Tag) {
+  case kDxilTemplateArgTypeTag: {
+    IFTBOOL(pTupleMD->getNumOperands() == 2, DXC_E_INCORRECT_DXIL_METADATA);
+    // The emitter stores the type as an undef value wrapped in
+    // ValueAsMetadata, so the type must be read from the wrapped value.
+    // Wrapping the operand with MetadataAsValue would always report the
+    // metadata type rather than the template argument type.
+    const ValueAsMetadata *pValueMD = dyn_cast_or_null<ValueAsMetadata>(
+      pTupleMD->getOperand(kDxilTemplateArgValue).get());
+    IFTBOOL(pValueMD != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+    annotation.SetType(pValueMD->getValue()->getType());
+  } break;
+  case kDxilTemplateArgIntegralTag:
+    IFTBOOL(pTupleMD->getNumOperands() == 2, DXC_E_INCORRECT_DXIL_METADATA);
+    annotation.SetIntegral((int64_t)ConstMDToUint64(pTupleMD->getOperand(kDxilTemplateArgValue)));
+    break;
+  default:
+    // Unknown tag: nothing to load.
+    break;
+  }
+}
+
 Metadata *DxilMDHelper::EmitDxilStructAnnotation(const DxilStructAnnotation &SA) {
-  vector<Metadata *> MDVals(SA.GetNumFields() + 1);
+  unsigned valMajor = 0, valMinor = 0;
+  if (m_pSM)
+    m_pSM->GetMinValidatorVersion(valMajor, valMinor);
+  bool bSupportExtended = !(valMajor == 1 && valMinor < 5);
+
+  vector<Metadata *> MDVals;
+  MDVals.reserve(SA.GetNumFields() + 2);  // In case of extended 1.5 property list
+  MDVals.resize(SA.GetNumFields() + 1);
+
   MDVals[0] = Uint32ToConstMD(SA.GetCBufferSize());
   for (unsigned i = 0; i < SA.GetNumFields(); i++) {
     MDVals[i+1] = EmitDxilFieldAnnotation(SA.GetFieldAnnotation(i));
   }
 
+  // Only add template args if shader target requires validator version that supports them.
+  if (bSupportExtended && SA.GetNumTemplateArgs()) {
+    vector<Metadata *> MDTemplateArgs(SA.GetNumTemplateArgs());
+    for (unsigned i = 0; i < SA.GetNumTemplateArgs(); ++i) {
+      MDTemplateArgs[i] = EmitDxilTemplateArgAnnotation(SA.GetTemplateArgAnnotation(i));
+    }
+    SmallVector<Metadata *, 2> MDExtraVals;
+    MDExtraVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilTemplateArgumentsTag));
+    MDExtraVals.emplace_back(MDNode::get(m_Ctx, MDTemplateArgs));
+    MDVals.emplace_back(MDNode::get(m_Ctx, MDExtraVals));
+  }
+
   return MDNode::get(m_Ctx, MDVals);
 }
 
@@ -791,7 +841,28 @@ void DxilMDHelper::LoadDxilStructAnnotation(const MDOperand &MDO, DxilStructAnno
   if (pTupleMD->getNumOperands() == 1) {
     SA.MarkEmptyStruct();
   }
-  IFTBOOL(pTupleMD->getNumOperands() == SA.GetNumFields()+1, DXC_E_INCORRECT_DXIL_METADATA);
+  unsigned valMajor = 0, valMinor = 0;
+  if (m_pSM)
+    m_pSM->GetMinValidatorVersion(valMajor, valMinor);
+  if (!(valMajor == 1 && valMinor < 5) &&
+      (pTupleMD->getNumOperands() == SA.GetNumFields()+2)) {
+    // Load template args from extended operand
+    const MDOperand &MDOExtra = pTupleMD->getOperand(SA.GetNumFields()+1);
+    const MDTuple *pTupleMDExtra = dyn_cast_or_null<MDTuple>(MDOExtra.get());
+    if(pTupleMDExtra) {
+      IFTBOOL(pTupleMDExtra->getNumOperands() % 2 == 0, DXC_E_INCORRECT_DXIL_METADATA);
+      unsigned Tag = ConstMDToUint32(pTupleMDExtra->getOperand(0));
+      IFTBOOL(Tag == kDxilTemplateArgumentsTag, DXC_E_INCORRECT_DXIL_METADATA); // Only one allowed at this point
+      const MDTuple *pTupleTemplateArgs = dyn_cast_or_null<MDTuple>(pTupleMDExtra->getOperand(1).get());
+      IFTBOOL(pTupleTemplateArgs, DXC_E_INCORRECT_DXIL_METADATA);
+      SA.SetNumTemplateArgs(pTupleTemplateArgs->getNumOperands());
+      for (unsigned i = 0; i < pTupleTemplateArgs->getNumOperands(); ++i) {
+        LoadDxilTemplateArgAnnotation(pTupleTemplateArgs->getOperand(i), SA.GetTemplateArgAnnotation(i));
+      }
+    }
+  } else {
+    IFTBOOL(pTupleMD->getNumOperands() == SA.GetNumFields()+1, DXC_E_INCORRECT_DXIL_METADATA);
+  }
 
   SA.SetCBufferSize(ConstMDToUint32(pTupleMD->getOperand(0)));
   for (unsigned i = 0; i < SA.GetNumFields(); i++) {
@@ -1024,6 +1095,28 @@ const Function *DxilMDHelper::LoadDxilFunctionProps(const MDTuple *pProps,
       props->ShaderProps.Ray.attributeSizeInBytes =
         ConstMDToUint32(pProps->getOperand(idx++));
     break;
+  case DXIL::ShaderKind::Mesh:
+    props->ShaderProps.MS.numThreads[0] =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.MS.numThreads[1] =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.MS.numThreads[2] =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.MS.maxVertexCount =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.MS.maxPrimitiveCount =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.MS.outputTopology =
+      (DXIL::MeshOutputTopology)ConstMDToUint32(pProps->getOperand(idx++));
+    break;
+  case DXIL::ShaderKind::Amplification:
+    props->ShaderProps.AS.numThreads[0] =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.AS.numThreads[1] =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.AS.numThreads[2] =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    break;
   default:
     break;
   }
@@ -1123,6 +1216,21 @@ MDTuple *DxilMDHelper::EmitDxilEntryProperties(uint64_t rawShaderFlag,
     MDVals.emplace_back(
         Uint32ToConstMD(props.ShaderProps.Ray.payloadSizeInBytes));
   } break;
+  case DXIL::ShaderKind::Mesh: {
+    auto &MS = props.ShaderProps.MS;
+    MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilMSStateTag));
+    MDTuple *pMDTuple = EmitDxilMSState(MS.numThreads,
+                                        MS.maxVertexCount,
+                                        MS.maxPrimitiveCount,
+                                        MS.outputTopology);
+    MDVals.emplace_back(pMDTuple);
+  } break;
+  case DXIL::ShaderKind::Amplification: {
+    auto &AS = props.ShaderProps.AS;
+    MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilASStateTag));
+    MDTuple *pMDTuple = EmitDxilASState(AS.numThreads);
+    MDVals.emplace_back(pMDTuple);
+  } break;
   default:
     break;
   }
@@ -1239,6 +1347,17 @@ void DxilMDHelper::LoadDxilEntryProperties(const MDOperand &MDO,
                "else invalid shader kind");
       props.shaderKind = kind;
     } break;
+    case DxilMDHelper::kDxilMSStateTag: {
+      DXASSERT(props.IsMS(), "else invalid shader kind");
+      auto &MS = props.ShaderProps.MS;
+      LoadDxilMSState(MDO, MS.numThreads, MS.maxVertexCount,
+                      MS.maxPrimitiveCount, MS.outputTopology);
+    } break;
+    case DxilMDHelper::kDxilASStateTag: {
+      DXASSERT(props.IsAS(), "else invalid shader kind");
+      auto &AS = props.ShaderProps.AS;
+      LoadDxilASState(MDO, AS.numThreads);
+    } break;
     default:
       DXASSERT(false, "Unknown extended shader properties tag");
       break;
@@ -1307,6 +1426,20 @@ DxilMDHelper::EmitDxilFunctionProps(const hlsl::DxilFunctionProps *props,
     if (bRayAttributes)
       MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.Ray.attributeSizeInBytes);
     break;
+  case DXIL::ShaderKind::Mesh:
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.MS.numThreads[0]);
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.MS.numThreads[1]);
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.MS.numThreads[2]);
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.MS.maxVertexCount);
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.MS.maxPrimitiveCount);
+    MDVals[valIdx++] =
+        Uint8ToConstMD((uint8_t)props->ShaderProps.MS.outputTopology);
+    break;
+  case DXIL::ShaderKind::Amplification:
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.AS.numThreads[0]);
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.AS.numThreads[1]);
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.AS.numThreads[2]);
+    break;
   default:
     break;
   }
@@ -1768,6 +1901,67 @@ void DxilMDHelper::LoadDxilHSState(const MDOperand &MDO,
   MaxTessFactor           = ConstMDToFloat(pTupleMD->getOperand(kDxilHSStateMaxTessellationFactor));
 }
 
+// Builds the metadata tuple that describes mesh shader state. The three
+// thread-group dimensions are wrapped in their own node; the maximum vertex
+// count, maximum primitive count and output topology are stored as 32-bit
+// constants. Each value is placed at its kDxilMSState* index.
+MDTuple *DxilMDHelper::EmitDxilMSState(const unsigned *NumThreads,
+                                       unsigned MaxVertexCount,
+                                       unsigned MaxPrimitiveCount,
+                                       DXIL::MeshOutputTopology OutputTopology) {
+  // NumThreads is read at indices 0, 1 and 2.
+  Metadata *ThreadDims[] = {
+      Uint32ToConstMD(NumThreads[0]),
+      Uint32ToConstMD(NumThreads[1]),
+      Uint32ToConstMD(NumThreads[2]),
+  };
+
+  Metadata *Fields[kDxilMSStateNumFields];
+  Fields[kDxilMSStateNumThreads] = MDNode::get(m_Ctx, ThreadDims);
+  Fields[kDxilMSStateMaxVertexCount] = Uint32ToConstMD(MaxVertexCount);
+  Fields[kDxilMSStateMaxPrimitiveCount] = Uint32ToConstMD(MaxPrimitiveCount);
+  Fields[kDxilMSStateOutputTopology] = Uint32ToConstMD((unsigned)OutputTopology);
+
+  return MDNode::get(m_Ctx, Fields);
+}
+
+// Reads mesh shader state back out of the metadata operand MDO.
+//
+// MDO must be a tuple with exactly kDxilMSStateNumFields operands, and its
+// kDxilMSStateNumThreads operand must be a node with exactly three operands.
+// Any violation is reported through IFTBOOL with
+// DXC_E_INCORRECT_DXIL_METADATA.
+//
+// Outputs: NumThreads[0..2], MaxVertexCount, MaxPrimitiveCount and
+// OutputTopology are overwritten with the loaded values.
+void DxilMDHelper::LoadDxilMSState(const MDOperand &MDO,
+                                   unsigned *NumThreads,
+                                   unsigned &MaxVertexCount,
+                                   unsigned &MaxPrimitiveCount,
+                                   DXIL::MeshOutputTopology &OutputTopology) {
+  IFTBOOL(MDO.get() != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  const MDTuple *pTupleMD = dyn_cast<MDTuple>(MDO.get());
+  IFTBOOL(pTupleMD != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  IFTBOOL(pTupleMD->getNumOperands() == kDxilMSStateNumFields, DXC_E_INCORRECT_DXIL_METADATA);
+
+  // The thread-count operand comes from the module being loaded, so check its
+  // kind and size before indexing instead of relying on cast<> and unchecked
+  // getOperand calls.
+  const MDNode *pNode = dyn_cast_or_null<MDNode>(pTupleMD->getOperand(kDxilMSStateNumThreads).get());
+  IFTBOOL(pNode != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  IFTBOOL(pNode->getNumOperands() == 3, DXC_E_INCORRECT_DXIL_METADATA);
+
+  NumThreads[0] = ConstMDToUint32(pNode->getOperand(0));
+  NumThreads[1] = ConstMDToUint32(pNode->getOperand(1));
+  NumThreads[2] = ConstMDToUint32(pNode->getOperand(2));
+  MaxVertexCount = ConstMDToUint32(pTupleMD->getOperand(kDxilMSStateMaxVertexCount));
+  MaxPrimitiveCount = ConstMDToUint32(pTupleMD->getOperand(kDxilMSStateMaxPrimitiveCount));
+  OutputTopology = (DXIL::MeshOutputTopology)ConstMDToUint32(pTupleMD->getOperand(kDxilMSStateOutputTopology));
+}
+
+// Builds the metadata tuple that describes amplification shader state: the
+// three thread-group dimensions wrapped in their own node and stored at
+// index kDxilASStateNumThreads.
+MDTuple *DxilMDHelper::EmitDxilASState(const unsigned *NumThreads) {
+  // NumThreads is read at indices 0, 1 and 2.
+  Metadata *ThreadDims[] = {
+      Uint32ToConstMD(NumThreads[0]),
+      Uint32ToConstMD(NumThreads[1]),
+      Uint32ToConstMD(NumThreads[2]),
+  };
+
+  Metadata *Fields[kDxilASStateNumFields];
+  Fields[kDxilASStateNumThreads] = MDNode::get(m_Ctx, ThreadDims);
+
+  return MDNode::get(m_Ctx, Fields);
+}
+
+// Reads amplification shader state back out of the metadata operand MDO.
+//
+// MDO must be a tuple with exactly kDxilASStateNumFields operands, and its
+// kDxilASStateNumThreads operand must be a node with exactly three operands.
+// Any violation is reported through IFTBOOL with
+// DXC_E_INCORRECT_DXIL_METADATA.
+//
+// Output: NumThreads[0..2] is overwritten with the loaded values.
+void DxilMDHelper::LoadDxilASState(const MDOperand &MDO, unsigned *NumThreads) {
+  IFTBOOL(MDO.get() != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  const MDTuple *pTupleMD = dyn_cast<MDTuple>(MDO.get());
+  IFTBOOL(pTupleMD != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  IFTBOOL(pTupleMD->getNumOperands() == kDxilASStateNumFields, DXC_E_INCORRECT_DXIL_METADATA);
+
+  // The thread-count operand comes from the module being loaded, so check its
+  // kind and size before indexing instead of relying on cast<> and unchecked
+  // getOperand calls.
+  const MDNode *pNode = dyn_cast_or_null<MDNode>(pTupleMD->getOperand(kDxilASStateNumThreads).get());
+  IFTBOOL(pNode != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  IFTBOOL(pNode->getNumOperands() == 3, DXC_E_INCORRECT_DXIL_METADATA);
+
+  NumThreads[0] = ConstMDToUint32(pNode->getOperand(0));
+  NumThreads[1] = ConstMDToUint32(pNode->getOperand(1));
+  NumThreads[2] = ConstMDToUint32(pNode->getOperand(2));
+}
+
 //
 // DxilExtraPropertyHelper methods.
 //

+ 74 - 4
lib/DXIL/DxilModule.cpp

@@ -627,6 +627,74 @@ void DxilModule::SetMaxTessellationFactor(float MaxTessellationFactor) {
   props.ShaderProps.HS.maxTessFactor = MaxTessellationFactor;
 }
 
+// Returns the maximum output vertex count stored in the mesh shader
+// properties of the module's single entry point, or 0 when the shader model
+// is not a mesh shader.
+unsigned DxilModule::GetMaxOutputVertices() const {
+  if (m_pSM->IsMS()) {
+    DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+    auto &entry = *m_DxilEntryPropsMap.begin();
+    DxilFunctionProps &props = entry.second->props;
+    DXASSERT(props.IsMS(), "Must be MS profile");
+    return props.ShaderProps.MS.maxVertexCount;
+  }
+  return 0;
+}
+
+// Stores NumOVs as the maximum output vertex count in the mesh shader
+// properties of the module's single entry point.
+void DxilModule::SetMaxOutputVertices(unsigned NumOVs) {
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsMS(),
+           "only works for MS profile");
+  auto &entry = *m_DxilEntryPropsMap.begin();
+  DxilFunctionProps &props = entry.second->props;
+  DXASSERT(props.IsMS(), "Must be MS profile");
+  auto &meshProps = props.ShaderProps.MS;
+  meshProps.maxVertexCount = NumOVs;
+}
+
+// Returns the maximum output primitive count stored in the mesh shader
+// properties of the module's single entry point, or 0 when the shader model
+// is not a mesh shader.
+unsigned DxilModule::GetMaxOutputPrimitives() const {
+  if (m_pSM->IsMS()) {
+    DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+    auto &entry = *m_DxilEntryPropsMap.begin();
+    DxilFunctionProps &props = entry.second->props;
+    DXASSERT(props.IsMS(), "Must be MS profile");
+    return props.ShaderProps.MS.maxPrimitiveCount;
+  }
+  return 0;
+}
+
+// Stores NumOPs as the maximum output primitive count in the mesh shader
+// properties of the module's single entry point.
+void DxilModule::SetMaxOutputPrimitives(unsigned NumOPs) {
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsMS(),
+           "only works for MS profile");
+  auto &entry = *m_DxilEntryPropsMap.begin();
+  DxilFunctionProps &props = entry.second->props;
+  DXASSERT(props.IsMS(), "Must be MS profile");
+  auto &meshProps = props.ShaderProps.MS;
+  meshProps.maxPrimitiveCount = NumOPs;
+}
+
+// Returns the output topology stored in the mesh shader properties of the
+// module's single entry point, or MeshOutputTopology::Undefined when the
+// shader model is not a mesh shader.
+DXIL::MeshOutputTopology DxilModule::GetMeshOutputTopology() const {
+  if (m_pSM->IsMS()) {
+    DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+    auto &entry = *m_DxilEntryPropsMap.begin();
+    DxilFunctionProps &props = entry.second->props;
+    DXASSERT(props.IsMS(), "Must be MS profile");
+    return props.ShaderProps.MS.outputTopology;
+  }
+  return DXIL::MeshOutputTopology::Undefined;
+}
+
+// Stores MeshOutputTopology as the output topology in the mesh shader
+// properties of the module's single entry point.
+void DxilModule::SetMeshOutputTopology(DXIL::MeshOutputTopology MeshOutputTopology) {
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsMS(),
+           "only works for MS profile");
+  auto &entry = *m_DxilEntryPropsMap.begin();
+  DxilFunctionProps &props = entry.second->props;
+  DXASSERT(props.IsMS(), "Must be MS profile");
+  auto &meshProps = props.ShaderProps.MS;
+  meshProps.outputTopology = MeshOutputTopology;
+}
+
+// Returns the payload byte size stored in the mesh shader properties of the
+// module's single entry point, or 0 when the shader model is not a mesh
+// shader.
+unsigned DxilModule::GetPayloadByteSize() const {
+  if (m_pSM->IsMS()) {
+    DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+    auto &entry = *m_DxilEntryPropsMap.begin();
+    DxilFunctionProps &props = entry.second->props;
+    DXASSERT(props.IsMS(), "Must be MS profile");
+    return props.ShaderProps.MS.payloadByteSize;
+  }
+  return 0;
+}
+
+// Stores Size as the payload byte size in the mesh shader properties of the
+// module's single entry point.
+void DxilModule::SetPayloadByteSize(unsigned Size) {
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsMS(),
+           "only works for MS profile");
+  auto &entry = *m_DxilEntryPropsMap.begin();
+  DxilFunctionProps &props = entry.second->props;
+  DXASSERT(props.IsMS(), "Must be MS profile");
+  auto &meshProps = props.ShaderProps.MS;
+  meshProps.payloadByteSize = Size;
+}
+
 void DxilModule::SetAutoBindingSpace(uint32_t Space) {
   m_AutoBindingSpace = Space;
 }
@@ -651,6 +719,8 @@ void DxilModule::SetShaderProperties(DxilFunctionProps *props) {
   case DXIL::ShaderKind::Domain:
   case DXIL::ShaderKind::Hull:
   case DXIL::ShaderKind::Vertex:
+  case DXIL::ShaderKind::Mesh:
+  case DXIL::ShaderKind::Amplification:
     break;
   default: {
     DXASSERT(props->shaderKind == DXIL::ShaderKind::Geometry,
@@ -929,16 +999,16 @@ const DxilSignature &DxilModule::GetOutputSignature() const {
   return m_DxilEntryPropsMap.begin()->second->sig.OutputSignature;
 }
 
-DxilSignature &DxilModule::GetPatchConstantSignature() {
+DxilSignature &DxilModule::GetPatchConstOrPrimSignature() {
   DXASSERT(m_DxilEntryPropsMap.size() == 1 && !m_pSM->IsLib(),
            "only works for non-lib profile");
-  return m_DxilEntryPropsMap.begin()->second->sig.PatchConstantSignature;
+  return m_DxilEntryPropsMap.begin()->second->sig.PatchConstOrPrimSignature;
 }
 
-const DxilSignature &DxilModule::GetPatchConstantSignature() const {
+const DxilSignature &DxilModule::GetPatchConstOrPrimSignature() const {
   DXASSERT(m_DxilEntryPropsMap.size() == 1 && !m_pSM->IsLib(),
            "only works for non-lib profile");
-  return m_DxilEntryPropsMap.begin()->second->sig.PatchConstantSignature;
+  return m_DxilEntryPropsMap.begin()->second->sig.PatchConstOrPrimSignature;
 }
 
 const std::vector<uint8_t> &DxilModule::GetSerializedRootSignature() const {

+ 217 - 6
lib/DXIL/DxilOperations.cpp

@@ -187,7 +187,7 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = {
   {  OC::Coverage,                "Coverage",                 OCC::Coverage,                 "coverage",                  { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadNone, },
   {  OC::InnerCoverage,           "InnerCoverage",            OCC::InnerCoverage,            "innerCoverage",             { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadNone, },
 
-  // Compute shader                                                                                                          void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  // Compute/Mesh/Amplification shader                                                                                       void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
   {  OC::ThreadId,                "ThreadId",                 OCC::ThreadId,                 "threadId",                  { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadNone, },
   {  OC::GroupId,                 "GroupId",                  OCC::GroupId,                  "groupId",                   { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadNone, },
   {  OC::ThreadIdInGroup,         "ThreadIdInGroup",          OCC::ThreadIdInGroup,          "threadIdInGroup",           { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadNone, },
@@ -321,6 +321,62 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = {
   {  OC::WaveMatch,               "WaveMatch",                OCC::WaveMatch,                "waveMatch",                 { false,  true,  true,  true, false,  true,  true,  true,  true, false, false}, Attribute::None,     },
   {  OC::WaveMultiPrefixOp,       "WaveMultiPrefixOp",        OCC::WaveMultiPrefixOp,        "waveMultiPrefixOp",         { false,  true,  true,  true, false,  true,  true,  true,  true, false, false}, Attribute::None,     },
   {  OC::WaveMultiPrefixBitCount, "WaveMultiPrefixBitCount",  OCC::WaveMultiPrefixBitCount,  "waveMultiPrefixBitCount",   {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+
+  // Mesh shader instructions                                                                                                void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  {  OC::SetMeshOutputCounts,     "SetMeshOutputCounts",      OCC::SetMeshOutputCounts,      "setMeshOutputCounts",       {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+  {  OC::EmitIndices,             "EmitIndices",              OCC::EmitIndices,              "emitIndices",               {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+  {  OC::GetMeshPayload,          "GetMeshPayload",           OCC::GetMeshPayload,           "getMeshPayload",            { false, false, false, false, false, false, false, false, false,  true, false}, Attribute::ReadOnly, },
+  {  OC::StoreVertexOutput,       "StoreVertexOutput",        OCC::StoreVertexOutput,        "storeVertexOutput",         { false,  true,  true, false, false, false,  true,  true, false, false, false}, Attribute::None,     },
+  {  OC::StorePrimitiveOutput,    "StorePrimitiveOutput",     OCC::StorePrimitiveOutput,     "storePrimitiveOutput",      { false,  true,  true, false, false, false,  true,  true, false, false, false}, Attribute::None,     },
+
+  // Amplification shader instructions                                                                                       void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  {  OC::DispatchMesh,            "DispatchMesh",             OCC::DispatchMesh,             "dispatchMesh",              { false, false, false, false, false, false, false, false, false,  true, false}, Attribute::None,     },
+
+  // Sampler Feedback                                                                                                        void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  {  OC::WriteSamplerFeedback,    "WriteSamplerFeedback",     OCC::WriteSamplerFeedback,     "writeSamplerFeedback",      {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+  {  OC::WriteSamplerFeedbackBias, "WriteSamplerFeedbackBias", OCC::WriteSamplerFeedbackBias, "writeSamplerFeedbackBias",  {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+  {  OC::WriteSamplerFeedbackLevel, "WriteSamplerFeedbackLevel", OCC::WriteSamplerFeedbackLevel, "writeSamplerFeedbackLevel", {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+  {  OC::WriteSamplerFeedbackGrad, "WriteSamplerFeedbackGrad", OCC::WriteSamplerFeedbackGrad, "writeSamplerFeedbackGrad",  {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+
+  // Inline Ray Query                                                                                                        void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  {  OC::AllocateRayQuery,        "AllocateRayQuery",         OCC::AllocateRayQuery,         "allocateRayQuery",          {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::ReadNone, },
+  {  OC::RayQuery_TraceRayInline, "RayQuery_TraceRayInline",  OCC::RayQuery_TraceRayInline,  "rayQuery_TraceRayInline",   {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+  {  OC::RayQuery_Proceed,        "RayQuery_Proceed",         OCC::RayQuery_Proceed,         "rayQuery_Proceed",          { false, false, false, false,  true, false, false, false, false, false, false}, Attribute::None,     },
+  {  OC::RayQuery_Abort,          "RayQuery_Abort",           OCC::RayQuery_Abort,           "rayQuery_Abort",            {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+  {  OC::RayQuery_CommitNonOpaqueTriangleHit, "RayQuery_CommitNonOpaqueTriangleHit", OCC::RayQuery_CommitNonOpaqueTriangleHit, "rayQuery_CommitNonOpaqueTriangleHit", {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+  {  OC::RayQuery_CommitProceduralPrimitiveHit, "RayQuery_CommitProceduralPrimitiveHit", OCC::RayQuery_CommitProceduralPrimitiveHit, "rayQuery_CommitProceduralPrimitiveHit", {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+  {  OC::RayQuery_CommittedStatus, "RayQuery_CommittedStatus", OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CandidateType,  "RayQuery_CandidateType",   OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CandidateObjectToWorld3x4, "RayQuery_CandidateObjectToWorld3x4", OCC::RayQuery_StateMatrix,     "rayQuery_StateMatrix",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CandidateWorldToObject3x4, "RayQuery_CandidateWorldToObject3x4", OCC::RayQuery_StateMatrix,     "rayQuery_StateMatrix",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CommittedObjectToWorld3x4, "RayQuery_CommittedObjectToWorld3x4", OCC::RayQuery_StateMatrix,     "rayQuery_StateMatrix",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CommittedWorldToObject3x4, "RayQuery_CommittedWorldToObject3x4", OCC::RayQuery_StateMatrix,     "rayQuery_StateMatrix",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CandidateProceduralPrimitiveNonOpaque, "RayQuery_CandidateProceduralPrimitiveNonOpaque", OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false,  true, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CandidateTriangleFrontFace, "RayQuery_CandidateTriangleFrontFace", OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false,  true, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CommittedTriangleFrontFace, "RayQuery_CommittedTriangleFrontFace", OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false,  true, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CandidateTriangleBarycentrics, "RayQuery_CandidateTriangleBarycentrics", OCC::RayQuery_StateVector,     "rayQuery_StateVector",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CommittedTriangleBarycentrics, "RayQuery_CommittedTriangleBarycentrics", OCC::RayQuery_StateVector,     "rayQuery_StateVector",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_RayFlags,       "RayQuery_RayFlags",        OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_WorldRayOrigin, "RayQuery_WorldRayOrigin",  OCC::RayQuery_StateVector,     "rayQuery_StateVector",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_WorldRayDirection, "RayQuery_WorldRayDirection", OCC::RayQuery_StateVector,     "rayQuery_StateVector",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_RayTMin,        "RayQuery_RayTMin",         OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CandidateTriangleRayT, "RayQuery_CandidateTriangleRayT", OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CommittedRayT,  "RayQuery_CommittedRayT",   OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CandidateInstanceIndex, "RayQuery_CandidateInstanceIndex", OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CandidateInstanceID, "RayQuery_CandidateInstanceID", OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CandidateGeometryIndex, "RayQuery_CandidateGeometryIndex", OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CandidatePrimitiveIndex, "RayQuery_CandidatePrimitiveIndex", OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CandidateObjectRayOrigin, "RayQuery_CandidateObjectRayOrigin", OCC::RayQuery_StateVector,     "rayQuery_StateVector",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CandidateObjectRayDirection, "RayQuery_CandidateObjectRayDirection", OCC::RayQuery_StateVector,     "rayQuery_StateVector",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CommittedInstanceIndex, "RayQuery_CommittedInstanceIndex", OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CommittedInstanceID, "RayQuery_CommittedInstanceID", OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CommittedGeometryIndex, "RayQuery_CommittedGeometryIndex", OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CommittedPrimitiveIndex, "RayQuery_CommittedPrimitiveIndex", OCC::RayQuery_StateScalar,     "rayQuery_StateScalar",      { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CommittedObjectRayOrigin, "RayQuery_CommittedObjectRayOrigin", OCC::RayQuery_StateVector,     "rayQuery_StateVector",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+  {  OC::RayQuery_CommittedObjectRayDirection, "RayQuery_CommittedObjectRayDirection", OCC::RayQuery_StateVector,     "rayQuery_StateVector",      { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadOnly, },
+
+  // Raytracing object space uint System Values, raytracing tier 1.1                                                         void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  {  OC::GeometryIndex,           "GeometryIndex",            OCC::GeometryIndex,            "geometryIndex",             { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadNone, },
 };
 // OPCODE-OLOADS:END
 
@@ -418,6 +474,11 @@ const char *OP::GetOpCodeClassName(OpCode opCode) {
   return m_OpCodeProps[(unsigned)opCode].pOpCodeClassName;
 }
 
+// Looks up the function attribute recorded for opCode in the m_OpCodeProps
+// table (its FuncAttr field).
+llvm::Attribute::AttrKind OP::GetMemAccessAttr(OpCode opCode) {
+  DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes, "otherwise caller passed OOB index");
+  const OpCodeProperty &entry = m_OpCodeProps[(unsigned)opCode];
+  return entry.FuncAttr;
+}
+
 bool OP::IsOverloadLegal(OpCode opCode, Type *pType) {
   DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes, "otherwise caller passed OOB index");
   unsigned TypeSlot = GetTypeSlot(pType);
@@ -533,7 +594,7 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
   // Instructions: ThreadId=93, GroupId=94, ThreadIdInGroup=95,
   // FlattenedThreadIdInGroup=96
   if ((93 <= op && op <= 96)) {
-    mask = SFLAG(Compute);
+    mask = SFLAG(Compute) | SFLAG(Mesh) | SFLAG(Amplification);
     return;
   }
   // Instructions: DomainLocation=105
@@ -585,7 +646,7 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
   // Instructions: ViewID=138
   if (op == 138) {
     major = 6;  minor = 1;
-    mask = SFLAG(Vertex) | SFLAG(Hull) | SFLAG(Domain) | SFLAG(Geometry) | SFLAG(Pixel);
+    mask = SFLAG(Vertex) | SFLAG(Hull) | SFLAG(Domain) | SFLAG(Geometry) | SFLAG(Pixel) | SFLAG(Mesh);
     return;
   }
   // Instructions: RawBufferLoad=139, RawBufferStore=140
@@ -657,9 +718,57 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
     return;
   }
   // Instructions: WaveMatch=165, WaveMultiPrefixOp=166,
-  // WaveMultiPrefixBitCount=167
-  if ((165 <= op && op <= 167)) {
+  // WaveMultiPrefixBitCount=167, WriteSamplerFeedbackLevel=176,
+  // WriteSamplerFeedbackGrad=177, AllocateRayQuery=178,
+  // RayQuery_TraceRayInline=179, RayQuery_Proceed=180, RayQuery_Abort=181,
+  // RayQuery_CommitNonOpaqueTriangleHit=182,
+  // RayQuery_CommitProceduralPrimitiveHit=183, RayQuery_CommittedStatus=184,
+  // RayQuery_CandidateType=185, RayQuery_CandidateObjectToWorld3x4=186,
+  // RayQuery_CandidateWorldToObject3x4=187,
+  // RayQuery_CommittedObjectToWorld3x4=188,
+  // RayQuery_CommittedWorldToObject3x4=189,
+  // RayQuery_CandidateProceduralPrimitiveNonOpaque=190,
+  // RayQuery_CandidateTriangleFrontFace=191,
+  // RayQuery_CommittedTriangleFrontFace=192,
+  // RayQuery_CandidateTriangleBarycentrics=193,
+  // RayQuery_CommittedTriangleBarycentrics=194, RayQuery_RayFlags=195,
+  // RayQuery_WorldRayOrigin=196, RayQuery_WorldRayDirection=197,
+  // RayQuery_RayTMin=198, RayQuery_CandidateTriangleRayT=199,
+  // RayQuery_CommittedRayT=200, RayQuery_CandidateInstanceIndex=201,
+  // RayQuery_CandidateInstanceID=202, RayQuery_CandidateGeometryIndex=203,
+  // RayQuery_CandidatePrimitiveIndex=204, RayQuery_CandidateObjectRayOrigin=205,
+  // RayQuery_CandidateObjectRayDirection=206,
+  // RayQuery_CommittedInstanceIndex=207, RayQuery_CommittedInstanceID=208,
+  // RayQuery_CommittedGeometryIndex=209, RayQuery_CommittedPrimitiveIndex=210,
+  // RayQuery_CommittedObjectRayOrigin=211,
+  // RayQuery_CommittedObjectRayDirection=212
+  if ((165 <= op && op <= 167) || (176 <= op && op <= 212)) {
+    major = 6;  minor = 5;
+    return;
+  }
+  // Instructions: DispatchMesh=173
+  if (op == 173) {
+    major = 6;  minor = 5;
+    mask = SFLAG(Amplification);
+    return;
+  }
+  // Instructions: GeometryIndex=213
+  if (op == 213) {
+    major = 6;  minor = 5;
+    mask = SFLAG(Library) | SFLAG(Intersection) | SFLAG(AnyHit) | SFLAG(ClosestHit);
+    return;
+  }
+  // Instructions: WriteSamplerFeedback=174, WriteSamplerFeedbackBias=175
+  if ((174 <= op && op <= 175)) {
     major = 6;  minor = 5;
+    mask = SFLAG(Library) | SFLAG(Pixel);
+    return;
+  }
+  // Instructions: SetMeshOutputCounts=168, EmitIndices=169, GetMeshPayload=170,
+  // StoreVertexOutput=171, StorePrimitiveOutput=172
+  if ((168 <= op && op <= 172)) {
+    major = 6;  minor = 5;
+    mask = SFLAG(Mesh);
     return;
   }
   // OPCODE-SMMASK:END
@@ -928,7 +1037,7 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
   case OpCode::Coverage:               A(pI32);     A(pI32); break;
   case OpCode::InnerCoverage:          A(pI32);     A(pI32); break;
 
-    // Compute shader
+    // Compute/Mesh/Amplification shader
   case OpCode::ThreadId:               A(pI32);     A(pI32); A(pI32); break;
   case OpCode::GroupId:                A(pI32);     A(pI32); A(pI32); break;
   case OpCode::ThreadIdInGroup:        A(pI32);     A(pI32); A(pI32); break;
@@ -1062,6 +1171,62 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
   case OpCode::WaveMatch:              A(pI4S);     A(pI32); A(pETy); break;
   case OpCode::WaveMultiPrefixOp:      A(pETy);     A(pI32); A(pETy); A(pI32); A(pI32); A(pI32); A(pI32); A(pI8);  A(pI8);  break;
   case OpCode::WaveMultiPrefixBitCount:A(pI32);     A(pI32); A(pI1);  A(pI32); A(pI32); A(pI32); A(pI32); break;
+
+    // Mesh shader instructions
+  case OpCode::SetMeshOutputCounts:    A(pV);       A(pI32); A(pI32); A(pI32); break;
+  case OpCode::EmitIndices:            A(pV);       A(pI32); A(pI32); A(pI32); A(pI32); A(pI32); break;
+  case OpCode::GetMeshPayload:         A(pETy);     A(pI32); break;
+  case OpCode::StoreVertexOutput:      A(pV);       A(pI32); A(pI32); A(pI32); A(pI8);  A(pETy); A(pI32); break;
+  case OpCode::StorePrimitiveOutput:   A(pV);       A(pI32); A(pI32); A(pI32); A(pI8);  A(pETy); A(pI32); break;
+
+    // Amplification shader instructions
+  case OpCode::DispatchMesh:           A(pV);       A(pI32); A(pI32); A(pI32); A(pI32); A(pETy); break;
+
+    // Sampler Feedback
+  case OpCode::WriteSamplerFeedback:   A(pV);       A(pI32); A(pRes); A(pRes); A(pRes); A(pF32); A(pF32); A(pF32); A(pF32); break;
+  case OpCode::WriteSamplerFeedbackBias:A(pV);       A(pI32); A(pRes); A(pRes); A(pRes); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); break;
+  case OpCode::WriteSamplerFeedbackLevel:A(pV);       A(pI32); A(pRes); A(pRes); A(pRes); A(pF32); A(pF32); A(pF32); A(pF32); break;
+  case OpCode::WriteSamplerFeedbackGrad:A(pV);       A(pI32); A(pRes); A(pRes); A(pRes); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); break;
+
+    // Inline Ray Query
+  case OpCode::AllocateRayQuery:       A(pI32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_TraceRayInline:A(pV);       A(pI32); A(pI32); A(pRes); A(pI32); A(pI32); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); break;
+  case OpCode::RayQuery_Proceed:       A(pI1);      A(pI32); A(pI32); break;
+  case OpCode::RayQuery_Abort:         A(pV);       A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CommitNonOpaqueTriangleHit:A(pV);       A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CommitProceduralPrimitiveHit:A(pV);       A(pI32); A(pI32); A(pF32); break;
+  case OpCode::RayQuery_CommittedStatus:A(pI32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CandidateType: A(pI32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CandidateObjectToWorld3x4:A(pF32);     A(pI32); A(pI32); A(pI32); A(pI8);  break;
+  case OpCode::RayQuery_CandidateWorldToObject3x4:A(pF32);     A(pI32); A(pI32); A(pI32); A(pI8);  break;
+  case OpCode::RayQuery_CommittedObjectToWorld3x4:A(pF32);     A(pI32); A(pI32); A(pI32); A(pI8);  break;
+  case OpCode::RayQuery_CommittedWorldToObject3x4:A(pF32);     A(pI32); A(pI32); A(pI32); A(pI8);  break;
+  case OpCode::RayQuery_CandidateProceduralPrimitiveNonOpaque:A(pI1);      A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CandidateTriangleFrontFace:A(pI1);      A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CommittedTriangleFrontFace:A(pI1);      A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CandidateTriangleBarycentrics:A(pF32);     A(pI32); A(pI32); A(pI8);  break;
+  case OpCode::RayQuery_CommittedTriangleBarycentrics:A(pF32);     A(pI32); A(pI32); A(pI8);  break;
+  case OpCode::RayQuery_RayFlags:      A(pI32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_WorldRayOrigin:A(pF32);     A(pI32); A(pI32); A(pI8);  break;
+  case OpCode::RayQuery_WorldRayDirection:A(pF32);     A(pI32); A(pI32); A(pI8);  break;
+  case OpCode::RayQuery_RayTMin:       A(pF32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CandidateTriangleRayT:A(pF32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CommittedRayT: A(pF32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CandidateInstanceIndex:A(pI32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CandidateInstanceID:A(pI32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CandidateGeometryIndex:A(pI32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CandidatePrimitiveIndex:A(pI32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CandidateObjectRayOrigin:A(pF32);     A(pI32); A(pI32); A(pI8);  break;
+  case OpCode::RayQuery_CandidateObjectRayDirection:A(pF32);     A(pI32); A(pI32); A(pI8);  break;
+  case OpCode::RayQuery_CommittedInstanceIndex:A(pI32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CommittedInstanceID:A(pI32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CommittedGeometryIndex:A(pI32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CommittedPrimitiveIndex:A(pI32);     A(pI32); A(pI32); break;
+  case OpCode::RayQuery_CommittedObjectRayOrigin:A(pF32);     A(pI32); A(pI32); A(pI8);  break;
+  case OpCode::RayQuery_CommittedObjectRayDirection:A(pF32);     A(pI32); A(pI32); A(pI8);  break;
+
+    // Raytracing object space uint System Values, raytracing tier 1.1
+  case OpCode::GeometryIndex:          A(pI32);     A(pI32); break;
   // OPCODE-OLOAD-FUNCS:END
   default: DXASSERT(false, "otherwise unhandled case"); break;
   }
@@ -1152,6 +1317,9 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   case OpCode::BufferStore:
   case OpCode::StorePatchConstant:
   case OpCode::RawBufferStore:
+  case OpCode::StoreVertexOutput:
+  case OpCode::StorePrimitiveOutput:
+  case OpCode::DispatchMesh:
     DXASSERT_NOMSG(FT->getNumParams() > 4);
     return FT->getParamType(4);
   case OpCode::IsNaN:
@@ -1215,6 +1383,17 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   case OpCode::IgnoreHit:
   case OpCode::AcceptHitAndEndSearch:
   case OpCode::WaveMultiPrefixBitCount:
+  case OpCode::SetMeshOutputCounts:
+  case OpCode::EmitIndices:
+  case OpCode::WriteSamplerFeedback:
+  case OpCode::WriteSamplerFeedbackBias:
+  case OpCode::WriteSamplerFeedbackLevel:
+  case OpCode::WriteSamplerFeedbackGrad:
+  case OpCode::AllocateRayQuery:
+  case OpCode::RayQuery_TraceRayInline:
+  case OpCode::RayQuery_Abort:
+  case OpCode::RayQuery_CommitNonOpaqueTriangleHit:
+  case OpCode::RayQuery_CommitProceduralPrimitiveHit:
     return Type::getVoidTy(m_Ctx);
   case OpCode::CheckAccessFullyMapped:
   case OpCode::AtomicBinOp:
@@ -1239,6 +1418,18 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   case OpCode::PrimitiveIndex:
   case OpCode::Dot4AddI8Packed:
   case OpCode::Dot4AddU8Packed:
+  case OpCode::RayQuery_CommittedStatus:
+  case OpCode::RayQuery_CandidateType:
+  case OpCode::RayQuery_RayFlags:
+  case OpCode::RayQuery_CandidateInstanceIndex:
+  case OpCode::RayQuery_CandidateInstanceID:
+  case OpCode::RayQuery_CandidateGeometryIndex:
+  case OpCode::RayQuery_CandidatePrimitiveIndex:
+  case OpCode::RayQuery_CommittedInstanceIndex:
+  case OpCode::RayQuery_CommittedInstanceID:
+  case OpCode::RayQuery_CommittedGeometryIndex:
+  case OpCode::RayQuery_CommittedPrimitiveIndex:
+  case OpCode::GeometryIndex:
     return IntegerType::get(m_Ctx, 32);
   case OpCode::CalculateLOD:
   case OpCode::DomainLocation:
@@ -1250,10 +1441,30 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   case OpCode::WorldToObject:
   case OpCode::RayTMin:
   case OpCode::RayTCurrent:
+  case OpCode::RayQuery_CandidateObjectToWorld3x4:
+  case OpCode::RayQuery_CandidateWorldToObject3x4:
+  case OpCode::RayQuery_CommittedObjectToWorld3x4:
+  case OpCode::RayQuery_CommittedWorldToObject3x4:
+  case OpCode::RayQuery_CandidateTriangleBarycentrics:
+  case OpCode::RayQuery_CommittedTriangleBarycentrics:
+  case OpCode::RayQuery_WorldRayOrigin:
+  case OpCode::RayQuery_WorldRayDirection:
+  case OpCode::RayQuery_RayTMin:
+  case OpCode::RayQuery_CandidateTriangleRayT:
+  case OpCode::RayQuery_CommittedRayT:
+  case OpCode::RayQuery_CandidateObjectRayOrigin:
+  case OpCode::RayQuery_CandidateObjectRayDirection:
+  case OpCode::RayQuery_CommittedObjectRayOrigin:
+  case OpCode::RayQuery_CommittedObjectRayDirection:
     return Type::getFloatTy(m_Ctx);
   case OpCode::MakeDouble:
   case OpCode::SplitDouble:
     return Type::getDoubleTy(m_Ctx);
+  case OpCode::RayQuery_Proceed:
+  case OpCode::RayQuery_CandidateProceduralPrimitiveNonOpaque:
+  case OpCode::RayQuery_CandidateTriangleFrontFace:
+  case OpCode::RayQuery_CommittedTriangleFrontFace:
+    return IntegerType::get(m_Ctx, 1);
   case OpCode::CBufferLoadLegacy:
   case OpCode::Sample:
   case OpCode::SampleBias:

+ 16 - 0
lib/DXIL/DxilResource.cpp

@@ -146,6 +146,10 @@ unsigned DxilResource::GetNumCoords(Kind ResourceKind) {
       0, // Sampler,
       1, // TBuffer,
       0, // RaytracingAccelerationStructure,
+      2, // FeedbackTexture2DMinLOD,
+      2, // FeedbackTexture2DTiled,
+      3, // FeedbackTexture2DArrayMinLOD,
+      3, // FeedbackTexture2DArrayTiled,
   };
   static_assert(_countof(CoordSizeTab) == (unsigned)Kind::NumEntries, "check helper array size");
   DXASSERT(ResourceKind > Kind::Invalid && ResourceKind < Kind::NumEntries, "otherwise the caller passed wrong resource type");
@@ -171,6 +175,10 @@ unsigned DxilResource::GetNumDimensions(Kind ResourceKind) {
       0, // Sampler,
       1, // TBuffer,
       0, // RaytracingAccelerationStructure,
+      2, // FeedbackTexture2DMinLOD,
+      2, // FeedbackTexture2DTiled,
+      2, // FeedbackTexture2DArrayMinLOD,
+      2, // FeedbackTexture2DArrayTiled,
   };
   static_assert(_countof(NumDimTab) == (unsigned)Kind::NumEntries, "check helper array size");
   DXASSERT(ResourceKind > Kind::Invalid && ResourceKind < Kind::NumEntries, "otherwise the caller passed wrong resource type");
@@ -196,6 +204,10 @@ unsigned DxilResource::GetNumDimensionsForCalcLOD(Kind ResourceKind) {
       0, // Sampler,
       1, // TBuffer,
       0, // RaytracingAccelerationStructure,
+      2, // FeedbackTexture2DMinLOD,
+      2, // FeedbackTexture2DTiled,
+      2, // FeedbackTexture2DArrayMinLOD,
+      2, // FeedbackTexture2DArrayTiled,
   };
   static_assert(_countof(NumDimTab) == (unsigned)Kind::NumEntries, "check helper array size");
   DXASSERT(ResourceKind > Kind::Invalid && ResourceKind < Kind::NumEntries, "otherwise the caller passed wrong resource type");
@@ -221,6 +233,10 @@ unsigned DxilResource::GetNumOffsets(Kind ResourceKind) {
       0, // Sampler,
       1, // TBuffer,
       0, // RaytracingAccelerationStructure,
+      2, // FeedbackTexture2DMinLOD,
+      2, // FeedbackTexture2DTiled,
+      2, // FeedbackTexture2DArrayMinLOD,
+      2, // FeedbackTexture2DArrayTiled,
   };
   static_assert(_countof(OffsetSizeTab) == (unsigned)Kind::NumEntries, "check helper array size");
   DXASSERT(ResourceKind > Kind::Invalid && ResourceKind < Kind::NumEntries, "otherwise the caller passed wrong resource type");

+ 15 - 7
lib/DXIL/DxilResourceBase.cpp

@@ -56,36 +56,44 @@ void DxilResourceBase::SetGlobalSymbol(llvm::Constant *pGV)       { m_pSymbol =
 void DxilResourceBase::SetGlobalName(const std::string &Name)     { m_Name = Name; }
 void DxilResourceBase::SetHandle(llvm::Value *pHandle)            { m_pHandle = pHandle; }
 
-static const char *s_ResourceClassNames[(unsigned)DxilResourceBase::Class::Invalid] = {
+static const char *s_ResourceClassNames[] = {
     "texture", "UAV", "cbuffer", "sampler"
 };
+static_assert(_countof(s_ResourceClassNames) == (unsigned)DxilResourceBase::Class::Invalid,
+  "Resource class names array must be updated when new resource class enums are added.");
 
 const char *DxilResourceBase::GetResClassName() const {
   return s_ResourceClassNames[(unsigned)m_Class];
 }
 
-static const char *s_ResourceIDPrefixs[(unsigned)DxilResourceBase::Class::Invalid] = {
+static const char *s_ResourceIDPrefixes[] = {
     "T", "U", "CB", "S"
 };
+static_assert(_countof(s_ResourceIDPrefixes) == (unsigned)DxilResourceBase::Class::Invalid,
+  "Resource id prefixes array must be updated when new resource class enums are added.");
 
 const char *DxilResourceBase::GetResIDPrefix() const {
-  return s_ResourceIDPrefixs[(unsigned)m_Class];
+  return s_ResourceIDPrefixes[(unsigned)m_Class];
 }
 
-static const char *s_ResourceBindPrefixs[(unsigned)DxilResourceBase::Class::Invalid] = {
+static const char *s_ResourceBindPrefixes[] = {
     "t", "u", "cb", "s"
 };
+static_assert(_countof(s_ResourceBindPrefixes) == (unsigned)DxilResourceBase::Class::Invalid,
+  "Resource bind prefixes array must be updated when new resource class enums are added.");
 
 const char *DxilResourceBase::GetResBindPrefix() const {
-  return s_ResourceBindPrefixs[(unsigned)m_Class];
+  return s_ResourceBindPrefixes[(unsigned)m_Class];
 }
 
-static const char *s_ResourceDimNames[(unsigned)DxilResourceBase::Kind::NumEntries] = {
+static const char *s_ResourceDimNames[] = {
         "invalid", "1d",        "2d",      "2dMS",      "3d",
         "cube",    "1darray",   "2darray", "2darrayMS", "cubearray",
         "buf",     "rawbuf",    "structbuf", "cbuffer", "sampler",
-        "tbuffer", "ras",
+        "tbuffer", "ras", "fbtex2dML", "fbtex2dT", "fbtex2darrayML", "fbtex2darrayT"
 };
+static_assert(_countof(s_ResourceDimNames) == (unsigned)DxilResourceBase::Kind::NumEntries,
+  "Resource dim names array must be updated when new resource kind enums are added.");
 
 const char *DxilResourceBase::GetResDimName() const {
   return s_ResourceDimNames[(unsigned)m_Kind];

+ 1 - 0
lib/DXIL/DxilSemantic.cpp

@@ -146,6 +146,7 @@ const Semantic Semantic::ms_SemanticTable[kNumSemanticRecords] = {
   SP(Kind::ViewID,                "SV_ViewID"),
   SP(Kind::Barycentrics,          "SV_Barycentrics"),
   SP(Kind::ShadingRate,           "SV_ShadingRate"),
+  SP(Kind::CullPrimitive,         "SV_CullPrimitive"),
   SP(Kind::Invalid,               nullptr),
 };
 

+ 14 - 0
lib/DXIL/DxilShaderFlags.cpp

@@ -47,6 +47,8 @@ ShaderFlags::ShaderFlags():
 , m_bBarycentrics(false)
 , m_bUseNativeLowPrecision(false)
 , m_bShadingRate(false)
+, m_bSamplerFeedback(false)
+, m_bRaytracingTier1_1(false)
 , m_align0(0)
 , m_align1(0)
 {}
@@ -93,6 +95,8 @@ uint64_t ShaderFlags::GetFeatureInfo() const {
   Flags |= m_bViewID ? hlsl::DXIL::ShaderFeatureInfo_ViewID : 0;
   Flags |= m_bBarycentrics ? hlsl::DXIL::ShaderFeatureInfo_Barycentrics : 0;
   Flags |= m_bShadingRate ? hlsl::DXIL::ShaderFeatureInfo_ShadingRate : 0;
+  Flags |= m_bRaytracingTier1_1 ? hlsl::DXIL::ShaderFeatureInfo_Raytracing_Tier_1_1 : 0;
+  Flags |= m_bSamplerFeedback ? hlsl::DXIL::ShaderFeatureInfo_SamplerFeedback : 0;
 
   return Flags;
 }
@@ -145,6 +149,8 @@ uint64_t ShaderFlags::GetShaderFlagsRawForCollection() {
   Flags.SetViewID(true);
   Flags.SetBarycentrics(true);
   Flags.SetShadingRate(true);
+  Flags.SetRaytracingTier1_1(true);
+  Flags.SetSamplerFeedback(true);
   return Flags.GetShaderFlagsRaw();
 }
 
@@ -247,6 +253,8 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F,
   bool hasMulticomponentUAVLoads = false;
   bool hasViewportOrRTArrayIndex = false;
   bool hasShadingRate = false;
+  bool hasSamplerFeedback = false;
+  bool hasRaytracingTier1_1 = false;
 
   // Try to maintain compatibility with a v1.0 validator if that's what we have.
   uint32_t valMajor, valMinor;
@@ -381,6 +389,10 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F,
         case DXIL::OpCode::ViewID:
           hasViewID = true;
           break;
+        case DXIL::OpCode::AllocateRayQuery:
+        case DXIL::OpCode::GeometryIndex:
+          hasRaytracingTier1_1 = true;
+          break;
         default:
           // Normal opcodes.
           break;
@@ -460,6 +472,8 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F,
   flag.SetViewID(hasViewID);
   flag.SetViewportAndRTArrayIndex(hasViewportOrRTArrayIndex);
   flag.SetShadingRate(hasShadingRate);
+  flag.SetSamplerFeedback(hasSamplerFeedback);
+  flag.SetRaytracingTier1_1(hasRaytracingTier1_1);
 
   return flag;
 }

+ 8 - 3
lib/DXIL/DxilShaderModel.cpp

@@ -43,7 +43,7 @@ bool ShaderModel::operator==(const ShaderModel &other) const {
 
 bool ShaderModel::IsValid() const {
   DXASSERT(IsPS() || IsVS() || IsGS() || IsHS() || IsDS() || IsCS() ||
-               IsLib() || m_Kind == Kind::Invalid,
+               IsLib() || IsMS() || IsAS() || m_Kind == Kind::Invalid,
            "invalid shader model");
   return m_Kind != Kind::Invalid;
 }
@@ -103,7 +103,7 @@ const ShaderModel *ShaderModel::Get(Kind Kind, unsigned Major, unsigned Minor) {
 }
 
 const ShaderModel *ShaderModel::GetByName(const char *pszName) {
-  // [ps|vs|gs|hs|ds|cs]_[major]_[minor]
+  // [ps|vs|gs|hs|ds|cs|ms|as]_[major]_[minor]
   Kind kind;
   switch (pszName[0]) {
   case 'p':   kind = Kind::Pixel;     break;
@@ -113,6 +113,8 @@ const ShaderModel *ShaderModel::GetByName(const char *pszName) {
   case 'd':   kind = Kind::Domain;    break;
   case 'c':   kind = Kind::Compute;   break;
   case 'l':   kind = Kind::Library;   break;
+  case 'm':   kind = Kind::Mesh;      break;
+  case 'a':   kind = Kind::Amplification; break;
   default:    return GetInvalid();
   }
   unsigned Idx = 3;
@@ -241,7 +243,7 @@ void ShaderModel::GetMinValidatorVersion(unsigned &ValMajor, unsigned &ValMinor)
 static const char *ShaderModelKindNames[] = {
     "ps", "vs", "gs", "hs", "ds", "cs", "lib",
     "raygeneration", "intersection", "anyhit", "closesthit", "miss", "callable",
-    "invalid",
+    "ms", "as", "invalid",
 };
 
 const char * ShaderModel::GetKindName() const {
@@ -331,6 +333,9 @@ const ShaderModel ShaderModel::ms_ShaderModels[kNumShaderModels] = {
   // lib_6_x is for offline linking only, and relaxes restrictions
   SM(Kind::Library,  6, kOfflineMinor, "lib_6_x",  32, 32,  true,  true,  UINT_MAX),
 
+  SM(Kind::Mesh,     6, 5, "ms_6_5",    0,  0,  true,  true,  UINT_MAX),
+  SM(Kind::Amplification, 6, 5, "as_6_5", 0, 0, true,  true,  UINT_MAX),
+
   // Values before Invalid must remain sorted by Kind, then Major, then Minor.
 
   SM(Kind::Invalid,  0, 0, "invalid", 0,  0,   false, false, 0),

+ 12 - 1
lib/DXIL/DxilSignature.cpp

@@ -118,13 +118,24 @@ unsigned DxilSignature::NumVectorsUsed(unsigned streamIndex) const {
   return NumVectors;
 }
 
+unsigned DxilSignature::GetRowCount() const { // Returns the largest (start row + row count) found across all elements.
+  unsigned maxRow = 0; // remains 0 when the signature has no elements
+  for (auto &E : GetElements()) {
+    unsigned endRow = E->GetStartRow() + E->GetRows(); // first row index past this element
+    if (maxRow < endRow) {
+      maxRow = endRow;
+    }
+  }
+  return maxRow;
+}
+
 //------------------------------------------------------------------------------
 //
 // EntrySingnature methods.
 //
 DxilEntrySignature::DxilEntrySignature(const DxilEntrySignature &src)
     : InputSignature(src.InputSignature), OutputSignature(src.OutputSignature),
-      PatchConstantSignature(src.PatchConstantSignature) {}
+      PatchConstOrPrimSignature(src.PatchConstOrPrimSignature) {}
 
 } // namespace hlsl
 

+ 2 - 2
lib/DXIL/DxilSignatureElement.cpp

@@ -88,8 +88,8 @@ bool DxilSignatureElement::IsOutput() const {
   return SigPoint::GetSigPoint(m_sigPointKind)->IsOutput();
 }
 
-bool DxilSignatureElement::IsPatchConstant() const {
-  return SigPoint::GetSigPoint(m_sigPointKind)->IsPatchConstant();
+bool DxilSignatureElement::IsPatchConstOrPrim() const {
+  return SigPoint::GetSigPoint(m_sigPointKind)->IsPatchConstOrPrim();
 }
 
 const char *DxilSignatureElement::GetName() const {

+ 56 - 1
lib/DXIL/DxilTypeSystem.cpp

@@ -78,6 +78,22 @@ const std::string &DxilFieldAnnotation::GetFieldName() const { return m_FieldNam
 void DxilFieldAnnotation::SetFieldName(const std::string &FieldName) { m_FieldName = FieldName; }
 
 
+//------------------------------------------------------------------------------
+//
+// DxilTemplateArgAnnotation class methods.
+//
+DxilTemplateArgAnnotation::DxilTemplateArgAnnotation() // Default state: no type, integral value 0.
+    : DxilFieldAnnotation(), m_Type(nullptr), m_Integral(0)
+{}
+
+bool DxilTemplateArgAnnotation::IsType() const { return m_Type != nullptr; } // type argument iff a type pointer is stored
+const llvm::Type *DxilTemplateArgAnnotation::GetType() const { return m_Type; } // nullptr when the argument is integral
+void DxilTemplateArgAnnotation::SetType(const llvm::Type *pType) { m_Type = pType; } // does not reset m_Integral; passing nullptr makes IsIntegral() true
+
+bool DxilTemplateArgAnnotation::IsIntegral() const { return m_Type == nullptr; } // exact complement of IsType(); also true for a default-constructed annotation
+int64_t DxilTemplateArgAnnotation::GetIntegral() const { return m_Integral; } // returns the stored value without checking IsIntegral()
+void DxilTemplateArgAnnotation::SetIntegral(int64_t i64) { m_Type = nullptr; m_Integral = i64; } // clears the type so the argument reads as integral
+
 //------------------------------------------------------------------------------
 //
 // DxilStructAnnotation class methods.
@@ -107,6 +123,22 @@ void DxilStructAnnotation::SetCBufferSize(unsigned size) { m_CBufferSize = size;
 void DxilStructAnnotation::MarkEmptyStruct() { m_FieldAnnotations.clear(); }
 bool DxilStructAnnotation::IsEmptyStruct() { return m_FieldAnnotations.empty(); }
 
+// For template args, GetNumTemplateArgs() will return 0 if not a template
+unsigned DxilStructAnnotation::GetNumTemplateArgs() const { // Number of template argument annotations currently stored.
+  return (unsigned)m_TemplateAnnotations.size();
+}
+void DxilStructAnnotation::SetNumTemplateArgs(unsigned count) { // Sizes the template argument list to `count` entries.
+  DXASSERT(m_TemplateAnnotations.empty(), "template args already initialized"); // asserts the list is still empty, i.e. it has not been sized to a non-zero count yet
+  m_TemplateAnnotations.resize(count);
+}
+DxilTemplateArgAnnotation &DxilStructAnnotation::GetTemplateArgAnnotation(unsigned argIdx) { // Mutable access to argument `argIdx`.
+  return m_TemplateAnnotations[argIdx]; // no range check here; presumably callers keep argIdx < GetNumTemplateArgs() -- TODO confirm
+}
+const DxilTemplateArgAnnotation &DxilStructAnnotation::GetTemplateArgAnnotation(unsigned argIdx) const { // Read-only overload of the accessor above.
+  return m_TemplateAnnotations[argIdx]; // same unchecked indexing as the non-const overload
+}
+
+
 //------------------------------------------------------------------------------
 //
 // DxilParameterAnnotation class methods.
@@ -170,12 +202,13 @@ DxilTypeSystem::DxilTypeSystem(Module *pModule)
     : m_pModule(pModule),
       m_LowPrecisionMode(DXIL::LowPrecisionMode::Undefined) {}
 
-DxilStructAnnotation *DxilTypeSystem::AddStructAnnotation(const StructType *pStructType) {
+DxilStructAnnotation *DxilTypeSystem::AddStructAnnotation(const StructType *pStructType, unsigned numTemplateArgs) {
   DXASSERT_NOMSG(m_StructAnnotations.find(pStructType) == m_StructAnnotations.end());
   DxilStructAnnotation *pA = new DxilStructAnnotation();
   m_StructAnnotations[pStructType] = unique_ptr<DxilStructAnnotation>(pA);
   pA->m_pStructType = pStructType;
   pA->m_FieldAnnotations.resize(pStructType->getNumElements());
+  pA->SetNumTemplateArgs(numTemplateArgs);
   return pA;
 }
 
@@ -410,6 +443,28 @@ DXIL::SigPointKind SigPointFromInputQual(DxilParamInputQual Q, DXIL::ShaderKind
       break;
     }
     break;
+  case DXIL::ShaderKind::Mesh:
+    switch (Q) {
+    case DxilParamInputQual::In:
+    case DxilParamInputQual::InPayload:
+      return DXIL::SigPointKind::MSIn;
+    case DxilParamInputQual::OutIndices:
+    case DxilParamInputQual::OutVertices:
+      return DXIL::SigPointKind::MSOut;
+    case DxilParamInputQual::OutPrimitives:
+      return DXIL::SigPointKind::MSPOut;
+    default:
+      break;
+    }
+    break;
+  case DXIL::ShaderKind::Amplification:
+    switch (Q) {
+    case DxilParamInputQual::In:
+      return DXIL::SigPointKind::ASIn;
+    default:
+      break;
+    }
+    break;
   default:
     break;
   }

+ 30 - 8
lib/DXIL/DxilUtil.cpp

@@ -467,11 +467,17 @@ llvm::Instruction *FirstNonAllocaInsertionPt(llvm::Function* F) {
   return SkipAllocas(FindAllocaInsertionPt(F));
 }
 
+static bool ConsumePrefix(StringRef &Str, StringRef Prefix) { // Strips Prefix from the front of Str in place; returns whether it was present.
+  if (!Str.startswith(Prefix)) return false; // no match: Str is left unchanged
+  Str = Str.substr(Prefix.size()); // drop exactly the matched prefix, nothing more
+  return true;
+}
+
 bool IsHLSLResourceType(llvm::Type *Ty) {
   if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
     StringRef name = ST->getName();
-    name = name.ltrim("class.");
-    name = name.ltrim("struct.");
+    ConsumePrefix(name, "class.");
+    ConsumePrefix(name, "struct.");
 
     if (name == "SamplerState")
       return true;
@@ -489,8 +495,13 @@ bool IsHLSLResourceType(llvm::Type *Ty) {
     if (name == "RaytracingAccelerationStructure")
       return true;
 
-    name = name.ltrim("RasterizerOrdered");
-    name = name.ltrim("RW");
+    if (ConsumePrefix(name, "FeedbackTexture2D")) {
+      ConsumePrefix(name, "Array");
+      return name == "MinLOD" || name == "Tiled";
+    }
+
+    ConsumePrefix(name, "RasterizerOrdered");
+    ConsumePrefix(name, "RW");
     if (name == "ByteAddressBuffer")
       return true;
 
@@ -499,8 +510,7 @@ bool IsHLSLResourceType(llvm::Type *Ty) {
     if (name.startswith("StructuredBuffer<"))
       return true;
 
-    if (name.startswith("Texture")) {
-      name = name.ltrim("Texture");
+    if (ConsumePrefix(name, "Texture")) {
       if (name.startswith("1D<"))
         return true;
       if (name.startswith("1DArray<"))
@@ -519,6 +529,7 @@ bool IsHLSLResourceType(llvm::Type *Ty) {
         return true;
       if (name.startswith("2DMSArray<"))
         return true;
+      return false;
     }
   }
   return false;
@@ -537,8 +548,8 @@ bool IsHLSLObjectType(llvm::Type *Ty) {
     if (IsHLSLResourceType(Ty))
       return true;
 
-    name = name.ltrim("class.");
-    name = name.ltrim("struct.");
+    ConsumePrefix(name, "class.");
+    ConsumePrefix(name, "struct.");
 
     if (name.startswith("TriangleStream<"))
       return true;
@@ -550,6 +561,17 @@ bool IsHLSLObjectType(llvm::Type *Ty) {
   return false;
 }
 
+bool IsHLSLRayQueryType(llvm::Type *Ty) { // True when Ty is a struct type whose name identifies a RayQuery<...> instantiation.
+  if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) { // non-struct types fall through to the final return false
+    StringRef name = ST->getName();
+    // TODO: don't check names.
+    ConsumePrefix(name, "class."); // only the "class." prefix is stripped here; a "struct." prefix is not
+    if (name.startswith("RayQuery<")) // match on the template name followed by its opening angle bracket
+      return true;
+  }
+  return false;
+}
+
 bool IsIntegerOrFloatingPointType(llvm::Type *Ty) {
   return Ty->isIntegerTy() || Ty->isFloatingPointTy();
 }

+ 71 - 25
lib/DxilContainer/DxilContainerAssembler.cpp

@@ -13,6 +13,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/Bitcode/ReaderWriter.h"
 #include "llvm/Support/MD5.h"
 #include "llvm/ADT/STLExtras.h"
@@ -312,9 +313,9 @@ DxilPartWriter *hlsl::NewProgramSignatureWriter(const DxilModule &M, DXIL::Signa
     return new DxilProgramSignatureWriter(
         M.GetOutputSignature(), domain, false,
         M.GetUseMinPrecision());
-  case DXIL::SignatureKind::PatchConstant:
+  case DXIL::SignatureKind::PatchConstOrPrim:
     return new DxilProgramSignatureWriter(
-        M.GetPatchConstantSignature(), domain,
+        M.GetPatchConstOrPrimSignature(), domain,
         /*IsInput*/ M.GetShaderModel()->IsDS(),
         /*UseMinPrecision*/M.GetUseMinPrecision());
   case DXIL::SignatureKind::Invalid:
@@ -372,7 +373,7 @@ private:
   SmallVector<uint32_t, 8> m_SemanticIndexBuffer;
   std::vector<PSVSignatureElement0> m_SigInputElements;
   std::vector<PSVSignatureElement0> m_SigOutputElements;
-  std::vector<PSVSignatureElement0> m_SigPatchConstantElements;
+  std::vector<PSVSignatureElement0> m_SigPatchConstOrPrimElements;
 
   void SetPSVSigElement(PSVSignatureElement0 &E, const DxilSignatureElement &SE) {
     memset(&E, 0, sizeof(PSVSignatureElement0));
@@ -468,8 +469,8 @@ public:
       m_SigInputElements.resize(m_PSVInitInfo.SigInputElements);
       m_PSVInitInfo.SigOutputElements = m_Module.GetOutputSignature().GetElements().size();
       m_SigOutputElements.resize(m_PSVInitInfo.SigOutputElements);
-      m_PSVInitInfo.SigPatchConstantElements = m_Module.GetPatchConstantSignature().GetElements().size();
-      m_SigPatchConstantElements.resize(m_PSVInitInfo.SigPatchConstantElements);
+      m_PSVInitInfo.SigPatchConstOrPrimElements = m_Module.GetPatchConstOrPrimSignature().GetElements().size();
+      m_SigPatchConstOrPrimElements.resize(m_PSVInitInfo.SigPatchConstOrPrimElements);
       uint32_t i = 0;
       for (auto &SE : m_Module.GetInputSignature().GetElements()) {
         SetPSVSigElement(m_SigInputElements[i++], *(SE.get()));
@@ -479,8 +480,8 @@ public:
         SetPSVSigElement(m_SigOutputElements[i++], *(SE.get()));
       }
       i = 0;
-      for (auto &SE : m_Module.GetPatchConstantSignature().GetElements()) {
-        SetPSVSigElement(m_SigPatchConstantElements[i++], *(SE.get()));
+      for (auto &SE : m_Module.GetPatchConstOrPrimSignature().GetElements()) {
+        SetPSVSigElement(m_SigPatchConstOrPrimElements[i++], *(SE.get()));
       }
       // Set String and SemanticInput Tables
       m_PSVInitInfo.StringTable.Table = m_StringBuffer.data();
@@ -493,12 +494,9 @@ public:
       for (unsigned streamIndex = 0; streamIndex < 4; streamIndex++) {
         m_PSVInitInfo.SigOutputVectors[streamIndex] = m_Module.GetOutputSignature().NumVectorsUsed(streamIndex);
       }
-      m_PSVInitInfo.SigPatchConstantVectors = 0;
-      if (SM->IsHS()) {
-        m_PSVInitInfo.SigPatchConstantVectors = m_Module.GetPatchConstantSignature().NumVectorsUsed(0);
-      }
-      if (SM->IsDS()) {
-        m_PSVInitInfo.SigPatchConstantVectors = m_Module.GetPatchConstantSignature().NumVectorsUsed(0);
+      m_PSVInitInfo.SigPatchConstOrPrimVectors = 0;
+      if (SM->IsHS() || SM->IsDS() || SM->IsMS()) {
+        m_PSVInitInfo.SigPatchConstOrPrimVectors = m_Module.GetPatchConstOrPrimSignature().NumVectorsUsed(0);
       }
     }
     if (!m_PSV.InitNew(m_PSVInitInfo, nullptr, &m_PSVBufferSize)) {
@@ -603,10 +601,58 @@ public:
         }
         break;
       }
+    case ShaderModel::Kind::Amplification:
     case ShaderModel::Kind::Compute:
     case ShaderModel::Kind::Library:
     case ShaderModel::Kind::Invalid:
-      // Compute, Library, and Invalide not relevant to PSVRuntimeInfo0
+      // Amplification, Compute, Library, and Invalid are not relevant to PSVRuntimeInfo0
+      break;
+    case ShaderModel::Kind::Mesh:
+      pInfo->MS.MaxOutputVertices = (UINT)m_Module.GetMaxOutputVertices();
+      pInfo->MS.MaxOutputPrimitives = (UINT)m_Module.GetMaxOutputPrimitives();
+      pInfo1->MeshOutputTopology = (UINT)m_Module.GetMeshOutputTopology();
+      Module *mod = m_Module.GetModule();
+      const DataLayout &DL = mod->getDataLayout();
+      unsigned totalByteSize = 0;
+      for (GlobalVariable &GV : mod->globals()) {
+        PointerType *gvPtrType = cast<PointerType>(GV.getType());
+        if (gvPtrType->getAddressSpace() == hlsl::DXIL::kTGSMAddrSpace) {
+          Type *gvType = gvPtrType->getPointerElementType();
+          unsigned byteSize = DL.getTypeAllocSize(gvType);
+          totalByteSize += byteSize;
+        }
+      }
+      pInfo->MS.GroupSharedBytesUsed = totalByteSize;
+
+      const Function *entryFunc = m_Module.GetEntryFunction();
+      unsigned payloadByteSize = 0;
+      for (auto b = entryFunc->begin(), bend = entryFunc->end(); b != bend; ++b) {
+        auto i = b->begin(), iend = b->end();
+        for (; i != iend; ++i) {
+          const Instruction &I = *i;
+
+          // Calls to external functions.
+          const CallInst *CI = dyn_cast<CallInst>(&I);
+          if (CI) {
+            Function *FCalled = CI->getCalledFunction();
+            if (FCalled->isDeclaration()) {
+              Value *opcodeVal = CI->getOperand(0);
+              ConstantInt *OpcodeConst = dyn_cast<ConstantInt>(opcodeVal);
+              unsigned opcode = OpcodeConst->getLimitedValue();
+              DXIL::OpCode dxilOpcode = (DXIL::OpCode)opcode;
+              if (dxilOpcode == DXIL::OpCode::GetMeshPayload) {
+                PointerType *payloadPTy = cast<PointerType>(CI->getType());
+                Type *payloadTy = payloadPTy->getPointerElementType();
+                payloadByteSize = DL.getTypeAllocSize(payloadTy);
+                break;
+              }
+            }
+          }
+        }
+        if (i != iend)
+          break;
+      }
+      pInfo->MS.PayloadSizeInBytes = payloadByteSize;
       break;
     }
 
@@ -689,10 +735,10 @@ public:
         DXASSERT_NOMSG(pOutputElement);
         memcpy(pOutputElement, &m_SigOutputElements[i], sizeof(PSVSignatureElement0));
       }
-      for (unsigned i = 0; i < m_PSV.GetSigPatchConstantElements(); i++) {
-        PSVSignatureElement0 *pPatchConstantElement = m_PSV.GetPatchConstantElement0(i);
-        DXASSERT_NOMSG(pPatchConstantElement);
-        memcpy(pPatchConstantElement, &m_SigPatchConstantElements[i], sizeof(PSVSignatureElement0));
+      for (unsigned i = 0; i < m_PSV.GetSigPatchConstOrPrimElements(); i++) {
+        PSVSignatureElement0 *pPatchConstOrPrimElement = m_PSV.GetPatchConstOrPrimElement0(i);
+        DXASSERT_NOMSG(pPatchConstOrPrimElement);
+        memcpy(pPatchConstOrPrimElement, &m_SigPatchConstOrPrimElements[i], sizeof(PSVSignatureElement0));
       }
 
       // Gather ViewID dependency information
@@ -707,7 +753,7 @@ public:
           if (!SM->IsGS())
             break;
         }
-        if (SM->IsHS()) {
+        if (SM->IsHS() || SM->IsMS()) {
           const uint32_t PCScalars = *(pSrc++);
           pSrc = CopyViewIDState(pSrc, InputScalars, PCScalars, m_PSV.GetViewIDPCOutputMask(), m_PSV.GetInputToPCOutputTable());
         } else if (SM->IsDS()) {
@@ -1491,7 +1537,7 @@ void hlsl::SerializeDxilContainerForModule(DxilModule *pModule,
 
   std::unique_ptr<DxilProgramSignatureWriter> pInputSigWriter = nullptr;
   std::unique_ptr<DxilProgramSignatureWriter> pOutputSigWriter = nullptr;
-  std::unique_ptr<DxilProgramSignatureWriter> pPatchConstantSigWriter = nullptr;
+  std::unique_ptr<DxilProgramSignatureWriter> pPatchConstOrPrimSigWriter = nullptr;
   if (!pModule->GetShaderModel()->IsLib()) {
     DXIL::TessellatorDomain domain = DXIL::TessellatorDomain::Undefined;
     if (pModule->GetShaderModel()->IsHS() || pModule->GetShaderModel()->IsDS())
@@ -1514,15 +1560,15 @@ void hlsl::SerializeDxilContainerForModule(DxilModule *pModule,
                      pOutputSigWriter->write(pStream);
                    });
 
-    pPatchConstantSigWriter = llvm::make_unique<DxilProgramSignatureWriter>(
-        pModule->GetPatchConstantSignature(), domain,
+    pPatchConstOrPrimSigWriter = llvm::make_unique<DxilProgramSignatureWriter>(
+        pModule->GetPatchConstOrPrimSignature(), domain,
         /*IsInput*/ pModule->GetShaderModel()->IsDS(),
         /*UseMinPrecision*/ pModule->GetUseMinPrecision());
-    if (pModule->GetPatchConstantSignature().GetElements().size()) {
+    if (pModule->GetPatchConstOrPrimSignature().GetElements().size()) {
       writer.AddPart(DFCC_PatchConstantSignature,
-                     pPatchConstantSigWriter->size(),
+                     pPatchConstOrPrimSigWriter->size(),
                      [&](AbstractMemoryStream *pStream) {
-                       pPatchConstantSigWriter->write(pStream);
+                       pPatchConstOrPrimSigWriter->write(pStream);
                      });
     }
   }

+ 56 - 63
lib/DxilPIXPasses/DxilShaderAccessTracking.cpp

@@ -48,7 +48,12 @@ enum class ShaderAccessFlags : uint32_t
 
   // "Counter" access is only applicable to UAVs; it means the counter buffer attached to the UAV
   // was accessed, but not necessarily the UAV resource.
-  Counter = 1 << 2
+  Counter = 1 << 2,
+
+  // Descriptor-only read (if any), but not the resource contents (if any).
+  // Used for GetDimensions, samplers, and secondary texture for sampler feedback.
+  // TODO: Make this a unique value if supported in PIX, then enable GetDimensions
+  DescriptorRead = 1 << 0,
 };
 
 // This enum doesn't have to match PIX's version, because the values are received from PIX encoded in ASCII.
@@ -435,73 +440,61 @@ bool DxilShaderAccessTracking::runOnModule(Module &M)
       DM.ReEmitDxilResources();
     }
 
-    struct ResourceAccessFunction
-    {
-      DXIL::OpCode opcode;
-      ShaderAccessFlags readWrite;
-      bool functionUsesSamplerAtIndex2;
-      std::vector<Type*> overloads;
-    };
-
-    std::vector<Type*> voidType = { Type::getVoidTy(Ctx) };
-    std::vector<Type*> i32 = { Type::getInt32Ty(Ctx) };
-    std::vector<Type*> f16f32 = { Type::getHalfTy(Ctx), Type::getFloatTy(Ctx) };
-    std::vector<Type*> f32i32 = { Type::getFloatTy(Ctx), Type::getInt32Ty(Ctx) };
-    std::vector<Type*> f32i32f64 = { Type::getFloatTy(Ctx), Type::getInt32Ty(Ctx), Type::getDoubleTy(Ctx) };
-    std::vector<Type*> f16f32i16i32 = { Type::getHalfTy(Ctx), Type::getFloatTy(Ctx), Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx) };
-    std::vector<Type*> f16f32f64i16i32i64 = { Type::getHalfTy(Ctx), Type::getFloatTy(Ctx), Type::getDoubleTy(Ctx), Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx), Type::getInt64Ty(Ctx) };
-
-
-    // todo: should "GetDimensions" mean a resource access?
-    static_assert(DXIL::OpCode::NumOpCodes == static_cast<DXIL::OpCode>(168), "Please update PIX passes if any resource access opcodes are added");
-    ResourceAccessFunction raFunctions[] = {
-      { DXIL::OpCode::CBufferLoadLegacy     , ShaderAccessFlags::Read   , false, f32i32f64 },
-      { DXIL::OpCode::CBufferLoad           , ShaderAccessFlags::Read   , false, f16f32f64i16i32i64 },
-      { DXIL::OpCode::Sample                , ShaderAccessFlags::Read   , true , f16f32 },
-      { DXIL::OpCode::SampleBias            , ShaderAccessFlags::Read   , true , f16f32 },
-      { DXIL::OpCode::SampleLevel           , ShaderAccessFlags::Read   , true , f16f32 },
-      { DXIL::OpCode::SampleGrad            , ShaderAccessFlags::Read   , true , f16f32 },
-      { DXIL::OpCode::SampleCmp             , ShaderAccessFlags::Read   , true , f16f32 },
-      { DXIL::OpCode::SampleCmpLevelZero    , ShaderAccessFlags::Read   , true , f16f32 },
-      { DXIL::OpCode::TextureLoad           , ShaderAccessFlags::Read   , false, f16f32i16i32 },
-      { DXIL::OpCode::TextureStore          , ShaderAccessFlags::Write  , false, f16f32i16i32 },
-      { DXIL::OpCode::TextureGather         , ShaderAccessFlags::Read   , true , f16f32i16i32 },
-      { DXIL::OpCode::TextureGatherCmp      , ShaderAccessFlags::Read   , false, f16f32i16i32 },
-      { DXIL::OpCode::BufferLoad            , ShaderAccessFlags::Read   , false, f32i32 },
-      { DXIL::OpCode::RawBufferLoad         , ShaderAccessFlags::Read   , false, f16f32i16i32 },
-      { DXIL::OpCode::RawBufferStore        , ShaderAccessFlags::Write  , false, f16f32i16i32 },
-      { DXIL::OpCode::BufferStore           , ShaderAccessFlags::Write  , false, f32i32 },
-      { DXIL::OpCode::BufferUpdateCounter   , ShaderAccessFlags::Counter, false, voidType },
-      { DXIL::OpCode::AtomicBinOp           , ShaderAccessFlags::Write  , false, i32 },
-      { DXIL::OpCode::AtomicCompareExchange , ShaderAccessFlags::Write  , false, i32 },
-    };
-
-    for (const auto & raFunction : raFunctions) {
-      for (const auto & Overload : raFunction.overloads) {
-        Function * TheFunction = HlslOP->GetOpFunc(raFunction.opcode, Overload);
-        auto TexLoadFunctionUses = TheFunction->uses();
-        for (auto FI = TexLoadFunctionUses.begin(); FI != TexLoadFunctionUses.end(); ) {
-          auto & FunctionUse = *FI++;
-          auto FunctionUser = FunctionUse.getUser();
-          auto instruction = cast<Instruction>(FunctionUser);
-
-          auto res = GetResourceFromHandle(instruction->getOperand(1), DM);
+    for (llvm::Function & F : M.functions()) {
+      // Only used DXIL intrinsics:
+      if (!F.isDeclaration() || F.isIntrinsic() || F.use_empty() || !OP::IsDxilOpFunc(&F))
+        continue;
+
+      // Gather handle parameter indices, if any
+      FunctionType *fnTy = cast<FunctionType>(F.getType()->getPointerElementType());
+      SmallVector<unsigned, 4> handleParams;
+      for (unsigned iParam = 1; iParam < fnTy->getFunctionNumParams(); ++iParam) {
+        if (fnTy->getParamType(iParam) == HlslOP->GetHandleType())
+          handleParams.push_back(iParam);
+      }
+      if (handleParams.empty())
+        continue;
+
+      auto FunctionUses = F.uses();
+      for (auto FI = FunctionUses.begin(); FI != FunctionUses.end(); ) {
+        auto & FunctionUse = *FI++;
+        auto FunctionUser = FunctionUse.getUser();
+        auto Call = cast<CallInst>(FunctionUser);
+        auto opCode = OP::GetDxilOpFuncCallInst(Call);
+
+        // Base Read/Write on function attribute - should match for all normal resource operations
+        ShaderAccessFlags readWrite = ShaderAccessFlags::Write;
+        if (OP::GetMemAccessAttr(opCode) == llvm::Attribute::AttrKind::ReadOnly)
+          readWrite = ShaderAccessFlags::Read;
+
+        // Special cases
+        switch (opCode) {
+        case DXIL::OpCode::GetDimensions:
+          // readWrite = ShaderAccessFlags::DescriptorRead;  // TODO: Support GetDimensions
+          continue;
+        case DXIL::OpCode::BufferUpdateCounter:
+          readWrite = ShaderAccessFlags::Counter;
+          break;
+        case DXIL::OpCode::TraceRay:
+        case DXIL::OpCode::RayQuery_TraceRayInline:
+          // Read of AccelerationStructure; doesn't match function attribute
+          // readWrite = ShaderAccessFlags::Read;  // TODO: Support TraceRay[Inline]
+          continue;
+        default:
+          break;
+        }
 
+        for (unsigned iParam : handleParams) {
+          auto res = GetResourceFromHandle(Call->getArgOperand(iParam), DM);
           // Don't instrument the accesses to the UAV that we just added
-          if (res.resource->GetSpaceID() == (unsigned)-2) {
-            continue;
+          if (res.resClass == DXIL::ResourceClass::UAV && res.resource->GetSpaceID() == (unsigned)-2) {
+            break;
           }
-
-          if (EmitResourceAccess(res, instruction, HlslOP, Ctx, raFunction.readWrite)) {
+          if (EmitResourceAccess(res, Call, HlslOP, Ctx, readWrite)) {
             Modified = true;
           }
-
-          if (raFunction.functionUsesSamplerAtIndex2) {
-            auto sampler = GetResourceFromHandle(instruction->getOperand(2), DM);
-            if (EmitResourceAccess(sampler, instruction, HlslOP, Ctx, ShaderAccessFlags::Read)) {
-              Modified = true;
-            }
-          }
+          // Remaining resources are DescriptorRead.
+          readWrite = ShaderAccessFlags::DescriptorRead;
         }
       }
     }

+ 46 - 33
lib/HLSL/ComputeViewIdState.cpp

@@ -42,11 +42,11 @@ DxilViewIdState::DxilViewIdState(DxilModule *pDxilModule)
     : m_pModule(pDxilModule) {}
 unsigned DxilViewIdState::getNumInputSigScalars() const                   { return m_NumInputSigScalars; }
 unsigned DxilViewIdState::getNumOutputSigScalars(unsigned StreamId) const { return m_NumOutputSigScalars[StreamId]; }
-unsigned DxilViewIdState::getNumPCSigScalars() const                      { return m_NumPCSigScalars; }
+unsigned DxilViewIdState::getNumPCSigScalars() const                      { return m_NumPCOrPrimSigScalars; }
 const DxilViewIdState::OutputsDependentOnViewIdType   &DxilViewIdState::getOutputsDependentOnViewId(unsigned StreamId) const    { return m_OutputsDependentOnViewId[StreamId]; }
-const DxilViewIdState::OutputsDependentOnViewIdType   &DxilViewIdState::getPCOutputsDependentOnViewId() const                   { return m_PCOutputsDependentOnViewId; }
+const DxilViewIdState::OutputsDependentOnViewIdType   &DxilViewIdState::getPCOutputsDependentOnViewId() const                   { return m_PCOrPrimOutputsDependentOnViewId; }
 const DxilViewIdState::InputsContributingToOutputType &DxilViewIdState::getInputsContributingToOutputs(unsigned StreamId) const { return m_InputsContributingToOutputs[StreamId]; }
-const DxilViewIdState::InputsContributingToOutputType &DxilViewIdState::getInputsContributingToPCOutputs() const                { return m_InputsContributingToPCOutputs; }
+const DxilViewIdState::InputsContributingToOutputType &DxilViewIdState::getInputsContributingToPCOutputs() const                { return m_InputsContributingToPCOrPrimOutputs; }
 const DxilViewIdState::InputsContributingToOutputType &DxilViewIdState::getPCInputsContributingToOutputs() const                { return m_PCInputsContributingToOutputs; }
 
 namespace {
@@ -95,39 +95,52 @@ void DxilViewIdState::PrintSets(llvm::raw_ostream &OS) {
   const ShaderModel *pSM = m_pModule->GetShaderModel();
   OS << "ViewId state: \n";
 
-  if (!pSM->IsGS()) {
-    OS << "Number of inputs: " << m_NumInputSigScalars     << 
-                 ", outputs: " << m_NumOutputSigScalars[0] << 
-              ", patchconst: " << m_NumPCSigScalars        << "\n";
-  } else {
+  if (pSM->IsGS()) {
     OS << "Number of inputs: "   << m_NumInputSigScalars     << 
                  ", outputs: { " << m_NumOutputSigScalars[0] << ", " << m_NumOutputSigScalars[1] << ", " <<
                                     m_NumOutputSigScalars[2] << ", " << m_NumOutputSigScalars[3] << " }" <<
-              ", patchconst: "   << m_NumPCSigScalars        << "\n";
+              ", patchconst: "   << m_NumPCOrPrimSigScalars        << "\n";
+  } else if (pSM->IsMS()) {
+    OS << "Number of inputs: " << m_NumInputSigScalars <<
+      ", vertex outputs: " << m_NumOutputSigScalars[0] <<
+      ", primitive outputs: " << m_NumPCOrPrimSigScalars << "\n";
+  } else {
+    OS << "Number of inputs: " << m_NumInputSigScalars <<
+      ", outputs: " << m_NumOutputSigScalars[0] <<
+      ", patchconst: " << m_NumPCOrPrimSigScalars << "\n";
   }
 
-  if (!pSM->IsGS()) {
-    PrintOutputsDependentOnViewId(OS, "Outputs", m_NumOutputSigScalars[0], m_OutputsDependentOnViewId[0]);
-  } else {
+  if (pSM->IsGS()) {
     PrintOutputsDependentOnViewId(OS, "Outputs for Stream0", m_NumOutputSigScalars[0], m_OutputsDependentOnViewId[0]);
     PrintOutputsDependentOnViewId(OS, "Outputs for Stream1", m_NumOutputSigScalars[1], m_OutputsDependentOnViewId[1]);
     PrintOutputsDependentOnViewId(OS, "Outputs for Stream2", m_NumOutputSigScalars[2], m_OutputsDependentOnViewId[2]);
     PrintOutputsDependentOnViewId(OS, "Outputs for Stream3", m_NumOutputSigScalars[3], m_OutputsDependentOnViewId[3]);
+  } else if (pSM->IsMS()) {
+    PrintOutputsDependentOnViewId(OS, "Vertex Outputs", m_NumOutputSigScalars[0], m_OutputsDependentOnViewId[0]);
+  } else {
+    PrintOutputsDependentOnViewId(OS, "Outputs", m_NumOutputSigScalars[0], m_OutputsDependentOnViewId[0]);
   }
+
   if (pSM->IsHS()) {
-    PrintOutputsDependentOnViewId(OS, "PCOutputs", m_NumPCSigScalars, m_PCOutputsDependentOnViewId);
+    PrintOutputsDependentOnViewId(OS, "PCOutputs", m_NumPCOrPrimSigScalars, m_PCOrPrimOutputsDependentOnViewId);
+  } else if (pSM->IsMS()) {
+    PrintOutputsDependentOnViewId(OS, "Primitive Outputs", m_NumPCOrPrimSigScalars, m_PCOrPrimOutputsDependentOnViewId);
   }
 
-  if (!pSM->IsGS()) {
-    PrintInputsContributingToOutputs(OS, "Inputs", "Outputs", m_InputsContributingToOutputs[0]);
-  } else {
+  if (pSM->IsGS()) {
     PrintInputsContributingToOutputs(OS, "Inputs", "Outputs for Stream0", m_InputsContributingToOutputs[0]);
     PrintInputsContributingToOutputs(OS, "Inputs", "Outputs for Stream1", m_InputsContributingToOutputs[1]);
     PrintInputsContributingToOutputs(OS, "Inputs", "Outputs for Stream2", m_InputsContributingToOutputs[2]);
     PrintInputsContributingToOutputs(OS, "Inputs", "Outputs for Stream3", m_InputsContributingToOutputs[3]);
+  } else if (pSM->IsMS()) {
+    PrintInputsContributingToOutputs(OS, "Inputs", "Vertex Outputs", m_InputsContributingToOutputs[0]);
+  } else {
+    PrintInputsContributingToOutputs(OS, "Inputs", "Outputs", m_InputsContributingToOutputs[0]);
   }
   if (pSM->IsHS()) {
-    PrintInputsContributingToOutputs(OS, "Inputs", "PCOutputs", m_InputsContributingToPCOutputs);
+    PrintInputsContributingToOutputs(OS, "Inputs", "PCOutputs", m_InputsContributingToPCOrPrimOutputs);
+  } else if (pSM->IsMS()) {
+    PrintInputsContributingToOutputs(OS, "Inputs", "Primitive Outputs", m_InputsContributingToPCOrPrimOutputs);
   } else if (pSM->IsDS()) {
     PrintInputsContributingToOutputs(OS, "PCInputs", "Outputs", m_PCInputsContributingToOutputs);
   }
@@ -141,9 +154,9 @@ void DxilViewIdState::Clear() {
     m_OutputsDependentOnViewId[i].reset();
     m_InputsContributingToOutputs[i].clear();
   }
-  m_NumPCSigScalars     = 0;
-  m_PCOutputsDependentOnViewId.reset();
-  m_InputsContributingToPCOutputs.clear();
+  m_NumPCOrPrimSigScalars     = 0;
+  m_PCOrPrimOutputsDependentOnViewId.reset();
+  m_InputsContributingToPCOrPrimOutputs.clear();
   m_PCInputsContributingToOutputs.clear();
   m_SerializedState.clear();
 }
@@ -212,15 +225,15 @@ void DxilViewIdState::Serialize() {
     }
     Size += NumInputs * NumOutUINTs; // m_InputsContributingToOutputs[StreamId]
   }
-  if (pSM->IsHS() || pSM->IsDS()) {
+  if (pSM->IsHS() || pSM->IsDS() || pSM->IsMS()) {
     Size += 1; // #PatchConstant.
     unsigned NumPCs = getNumPCSigScalars();
     unsigned NumPCUINTs = RoundUpToUINT(NumPCs);
-    if (pSM->IsHS()) {
+    if (pSM->IsHS() || pSM->IsMS()) {
       if (m_bUsesViewId) {
-        Size += NumPCUINTs; // m_PCOutputsDependentOnViewId
+        Size += NumPCUINTs; // m_PCOrPrimOutputsDependentOnViewId
       }
-      Size += NumInputs * NumPCUINTs; // m_InputsContributingToPCOutputs
+      Size += NumInputs * NumPCUINTs; // m_InputsContributingToPCOrPrimOutputs
     } else {
       unsigned NumOutputs = getNumOutputSigScalars(0);
       unsigned NumOutUINTs = RoundUpToUINT(NumOutputs);
@@ -244,16 +257,16 @@ void DxilViewIdState::Serialize() {
     SerializeInputsContributingToOutput(
         NumInputs, NumOutputs, m_InputsContributingToOutputs[StreamId], pData);
   }
-  if (pSM->IsHS() || pSM->IsDS()) {
+  if (pSM->IsHS() || pSM->IsDS() || pSM->IsMS()) {
     unsigned NumPCs = getNumPCSigScalars();
     *pData++ = NumPCs;
-    if (pSM->IsHS()) {
+    if (pSM->IsHS() || pSM->IsMS()) {
       if (m_bUsesViewId) {
-        SerializeOutputsDependentOnViewId(NumPCs, m_PCOutputsDependentOnViewId,
+        SerializeOutputsDependentOnViewId(NumPCs, m_PCOrPrimOutputsDependentOnViewId,
                                           pData);
       }
       SerializeInputsContributingToOutput(
-          NumInputs, NumPCs, m_InputsContributingToPCOutputs, pData);
+          NumInputs, NumPCs, m_InputsContributingToPCOrPrimOutputs, pData);
     } else {
       unsigned NumOutputs = getNumOutputSigScalars(0);
       SerializeInputsContributingToOutput(
@@ -348,18 +361,18 @@ void DxilViewIdState::Deserialize(const unsigned *pData,
         &pData[ConsumedUINTs], DataSizeInUINTs - ConsumedUINTs);
   }
 
-  if (pSM->IsHS() || pSM->IsDS()) {
+  if (pSM->IsHS() || pSM->IsDS() || pSM->IsMS()) {
     IFTBOOL(DataSizeInUINTs - ConsumedUINTs >= 1, DXC_E_GENERAL_INTERNAL_ERROR);
     unsigned NumPCs = pData[ConsumedUINTs++];
-    m_NumPCSigScalars = NumPCs;
-    if (pSM->IsHS()) {
+    m_NumPCOrPrimSigScalars = NumPCs;
+    if (pSM->IsHS() || pSM->IsMS()) {
       if (m_bUsesViewId) {
         ConsumedUINTs += DeserializeOutputsDependentOnViewId(
-            NumPCs, m_PCOutputsDependentOnViewId, &pData[ConsumedUINTs],
+            NumPCs, m_PCOrPrimOutputsDependentOnViewId, &pData[ConsumedUINTs],
             DataSizeInUINTs - ConsumedUINTs);
       }
       ConsumedUINTs += DeserializeInputsContributingToOutput(
-          NumInputs, NumPCs, m_InputsContributingToPCOutputs,
+          NumInputs, NumPCs, m_InputsContributingToPCOrPrimOutputs,
           &pData[ConsumedUINTs], DataSizeInUINTs - ConsumedUINTs);
     } else {
       unsigned NumOutputs = getNumOutputSigScalars(0);

+ 48 - 17
lib/HLSL/ComputeViewIdStateBuilder.cpp

@@ -7,7 +7,9 @@
 //                                                                           //
 ///////////////////////////////////////////////////////////////////////////////
 
+#include "dxc/HlslIntrinsicOp.h"
 #include "dxc/HLSL/ComputeViewIdState.h"
+#include "dxc/HLSL/HLOperations.h"
 #include "dxc/Support/Global.h"
 #include "dxc/DXIL/DxilModule.h"
 #include "dxc/DXIL/DxilOperations.h"
@@ -52,13 +54,13 @@ public:
         m_NumInputSigScalars(state.m_NumInputSigScalars),
         m_NumOutputSigScalars(state.m_NumOutputSigScalars,
                               DxilViewIdStateData::kNumStreams),
-        m_NumPCSigScalars(state.m_NumPCSigScalars),
+        m_NumPCOrPrimSigScalars(state.m_NumPCOrPrimSigScalars),
         m_OutputsDependentOnViewId(state.m_OutputsDependentOnViewId,
                                    DxilViewIdStateData::kNumStreams),
-        m_PCOutputsDependentOnViewId(state.m_PCOutputsDependentOnViewId),
+        m_PCOrPrimOutputsDependentOnViewId(state.m_PCOrPrimOutputsDependentOnViewId),
         m_InputsContributingToOutputs(state.m_InputsContributingToOutputs,
                                       DxilViewIdStateData::kNumStreams),
-        m_InputsContributingToPCOutputs(state.m_InputsContributingToPCOutputs),
+        m_InputsContributingToPCOrPrimOutputs(state.m_InputsContributingToPCOrPrimOutputs),
         m_PCInputsContributingToOutputs(state.m_PCInputsContributingToOutputs),
         m_bUsesViewId(state.m_bUsesViewId) {}
 
@@ -71,15 +73,15 @@ private:
 
   unsigned &m_NumInputSigScalars;
   MutableArrayRef<unsigned> m_NumOutputSigScalars;
-  unsigned &m_NumPCSigScalars;
+  unsigned &m_NumPCOrPrimSigScalars;
 
   // Set of scalar outputs dependent on ViewID.
   MutableArrayRef<OutputsDependentOnViewIdType> m_OutputsDependentOnViewId;
-  OutputsDependentOnViewIdType &m_PCOutputsDependentOnViewId;
+  OutputsDependentOnViewIdType &m_PCOrPrimOutputsDependentOnViewId;
 
   // Set of scalar inputs contributing to computation of scalar outputs.
   MutableArrayRef<InputsContributingToOutputType> m_InputsContributingToOutputs;
-  InputsContributingToOutputType &m_InputsContributingToPCOutputs; // HS PC only.
+  InputsContributingToOutputType &m_InputsContributingToPCOrPrimOutputs; // HS PC and MS Prim only.
   InputsContributingToOutputType &m_PCInputsContributingToOutputs; // DS only.
 
   bool &m_bUsesViewId;
@@ -174,7 +176,7 @@ void DxilViewIdStateBuilder::Compute() {
   // 1. Traverse signature MD to determine max packed location.
   DetermineMaxPackedLocation(m_pModule->GetInputSignature(), &m_NumInputSigScalars, 1);
   DetermineMaxPackedLocation(m_pModule->GetOutputSignature(), &m_NumOutputSigScalars[0], pSM->IsGS() ? kNumStreams : 1);
-  DetermineMaxPackedLocation(m_pModule->GetPatchConstantSignature(), &m_NumPCSigScalars, 1);
+  DetermineMaxPackedLocation(m_pModule->GetPatchConstOrPrimSignature(), &m_NumPCOrPrimSigScalars, 1);
 
   // 2. Collect sets of functions reachable from main and pc entries.
   CallGraphAnalysis CGA;
@@ -205,10 +207,10 @@ void DxilViewIdStateBuilder::Compute() {
                      m_OutputsDependentOnViewId[StreamId],
                      m_InputsContributingToOutputs[StreamId], false);
   }
-  if (pSM->IsHS()) {
+  if (pSM->IsHS() || pSM->IsMS()) {
     CreateViewIdSets(m_PCEntry.ContributingInstructions[0],
-                     m_PCOutputsDependentOnViewId,
-                     m_InputsContributingToPCOutputs, true);
+                     m_PCOrPrimOutputsDependentOnViewId,
+                     m_InputsContributingToPCOrPrimOutputs, true);
   } else if (pSM->IsDS()) {
     OutputsDependentOnViewIdType OutputsDependentOnViewId;
     CreateViewIdSets(m_Entry.ContributingInstructions[0],
@@ -233,12 +235,12 @@ void DxilViewIdStateBuilder::Clear() {
     m_OutputsDependentOnViewId[i].reset();
     m_InputsContributingToOutputs[i].clear();
   }
-  m_NumPCSigScalars     = 0;
+  m_NumPCOrPrimSigScalars     = 0;
   m_InpSigDynIdxElems.clear();
   m_OutSigDynIdxElems.clear();
   m_PCSigDynIdxElems.clear();
-  m_PCOutputsDependentOnViewId.reset();
-  m_InputsContributingToPCOutputs.clear();
+  m_PCOrPrimOutputsDependentOnViewId.reset();
+  m_InputsContributingToPCOrPrimOutputs.clear();
   m_PCInputsContributingToOutputs.clear();
   m_Entry.Clear();
   m_PCEntry.Clear();
@@ -340,6 +342,18 @@ void DxilViewIdStateBuilder::AnalyzeFunctions(EntryInfo &Entry) {
           GetUnsignedVal(SO.get_rowIndex(), (uint32_t*)&row);
           IFTBOOL(GetUnsignedVal(SO.get_colIndex(), &col), DXC_E_GENERAL_INTERNAL_ERROR);
           Entry.Outputs.emplace(CI);
+        } else if (DxilInst_StoreVertexOutput SVO = DxilInst_StoreVertexOutput(CI)) {
+          pDynIdxElems = &m_OutSigDynIdxElems;
+          IFTBOOL(GetUnsignedVal(SVO.get_outputSigId(), &id), DXC_E_GENERAL_INTERNAL_ERROR);
+          GetUnsignedVal(SVO.get_rowIndex(), (uint32_t*)&row);
+          IFTBOOL(GetUnsignedVal(SVO.get_colIndex(), &col), DXC_E_GENERAL_INTERNAL_ERROR);
+          Entry.Outputs.emplace(CI);
+        } else if (DxilInst_StorePrimitiveOutput SPO = DxilInst_StorePrimitiveOutput(CI)) {
+          pDynIdxElems = &m_PCSigDynIdxElems;
+          IFTBOOL(GetUnsignedVal(SPO.get_outputSigId(), &id), DXC_E_GENERAL_INTERNAL_ERROR);
+          GetUnsignedVal(SPO.get_rowIndex(), (uint32_t*)&row);
+          IFTBOOL(GetUnsignedVal(SPO.get_colIndex(), &col), DXC_E_GENERAL_INTERNAL_ERROR);
+          Entry.Outputs.emplace(CI);
         } else if (DxilInst_LoadPatchConstant LPC = DxilInst_LoadPatchConstant(CI)) {
           if (m_pModule->GetShaderModel()->IsDS()) {
             pDynIdxElems = &m_PCSigDynIdxElems;
@@ -412,8 +426,20 @@ void DxilViewIdStateBuilder::CollectValuesContributingToOutputs(EntryInfo &Entry
       GetUnsignedVal(SO.get_outputSigId(), &id);
       GetUnsignedVal(SO.get_colIndex(), &col);
       GetUnsignedVal(SO.get_rowIndex(), (uint32_t*)&startRow);
+    } else if (DxilInst_StoreVertexOutput SVO = DxilInst_StoreVertexOutput(CI)) {
+      pDxilSig = &m_pModule->GetOutputSignature(); // vertex outputs belong to the output signature, not the primitive-output one
+      pContributingValue = SVO.get_value();
+      GetUnsignedVal(SVO.get_outputSigId(), &id);
+      GetUnsignedVal(SVO.get_colIndex(), &col);
+      GetUnsignedVal(SVO.get_rowIndex(), (uint32_t*)&startRow);
+    } else if (DxilInst_StorePrimitiveOutput SPO = DxilInst_StorePrimitiveOutput(CI)) {
+      pDxilSig = &m_pModule->GetPatchConstOrPrimSignature();
+      pContributingValue = SPO.get_value();
+      GetUnsignedVal(SPO.get_outputSigId(), &id);
+      GetUnsignedVal(SPO.get_colIndex(), &col);
+      GetUnsignedVal(SPO.get_rowIndex(), (uint32_t*)&startRow);
     } else if (DxilInst_StorePatchConstant SPC = DxilInst_StorePatchConstant(CI)) {
-      pDxilSig = &m_pModule->GetPatchConstantSignature();
+      pDxilSig = &m_pModule->GetPatchConstOrPrimSignature();
       pContributingValue = SPC.get_value();
       GetUnsignedVal(SPC.get_outputSigID(), &id);
       GetUnsignedVal(SPC.get_row(), (uint32_t*)&startRow);
@@ -669,6 +695,11 @@ void DxilViewIdStateBuilder::CollectReachingDeclsRec(Value *pValue, ValueSetType
     CollectReachingDeclsRec(SelI->getFalseValue(), ReachingDecls, Visited);
   } else if (dyn_cast<Argument>(pValue)) {
     ReachingDecls.emplace(pValue);
+  } else if (CallInst *call = dyn_cast<CallInst>(pValue)) {
+    Function *func = call->getCalledFunction();
+    IFTBOOL(func != nullptr, DXC_E_GENERAL_INTERNAL_ERROR); // indirect call: no callee to identify, treat like the unknown-value case
+    DXASSERT(func->getName().startswith("dx.op.getMeshPayload"), "the function must be @dx.op.getMeshPayload here.");
+    ReachingDecls.emplace(pValue);
   } else {
     IFT(DXC_E_GENERAL_INTERNAL_ERROR);
   }
@@ -762,7 +793,7 @@ void DxilViewIdStateBuilder::CreateViewIdSets(const std::unordered_map<unsigned,
           GetUnsignedVal(LPC.get_inputSigId(), &inpId);
           GetUnsignedVal(LPC.get_col(), &col);
           GetUnsignedVal(LPC.get_row(), (uint32_t*)&startRow);
-          pSigElem = &m_pModule->GetPatchConstantSignature().GetElement(inpId);
+          pSigElem = &m_pModule->GetPatchConstOrPrimSignature().GetElement(inpId);
         }
       } else {
         continue;
@@ -787,7 +818,7 @@ void DxilViewIdStateBuilder::CreateViewIdSets(const std::unordered_map<unsigned,
             // This HS patch-constant output depends on an input value of LoadOutputControlPoint
             // that is the output value of the HS main (control-point) function.
             // Transitively update this (patch-constant) output dependence on main (control-point) output.
-            DXASSERT_NOMSG(&OutputsDependentOnViewId == &m_PCOutputsDependentOnViewId);
+            DXASSERT_NOMSG(&OutputsDependentOnViewId == &m_PCOrPrimOutputsDependentOnViewId);
             OutputsDependentOnViewId[outIdx] = OutputsDependentOnViewId[outIdx] || m_OutputsDependentOnViewId[0][index];
 
             const auto it = m_InputsContributingToOutputs[0].find(index);
@@ -813,7 +844,7 @@ unsigned DxilViewIdStateBuilder::GetLinearIndex(DxilSignatureElement &SigElem, i
 void DxilViewIdStateBuilder::UpdateDynamicIndexUsageState() const {
   UpdateDynamicIndexUsageStateForSig(m_pModule->GetInputSignature(), m_InpSigDynIdxElems);
   UpdateDynamicIndexUsageStateForSig(m_pModule->GetOutputSignature(), m_OutSigDynIdxElems);
-  UpdateDynamicIndexUsageStateForSig(m_pModule->GetPatchConstantSignature(), m_PCSigDynIdxElems);
+  UpdateDynamicIndexUsageStateForSig(m_pModule->GetPatchConstOrPrimSignature(), m_PCSigDynIdxElems);
 }
 
 void DxilViewIdStateBuilder::UpdateDynamicIndexUsageStateForSig(DxilSignature &Sig,

+ 11 - 18
lib/HLSL/DxilContainerReflection.cpp

@@ -1854,7 +1854,7 @@ HRESULT DxilShaderReflection::Load(IDxcBlob *pBlob,
     // Populate input/output/patch constant signatures.
     CreateReflectionObjectsForSignature(m_pDxilModule->GetInputSignature(), m_InputSignature);
     CreateReflectionObjectsForSignature(m_pDxilModule->GetOutputSignature(), m_OutputSignature);
-    CreateReflectionObjectsForSignature(m_pDxilModule->GetPatchConstantSignature(), m_PatchConstantSignature);
+    CreateReflectionObjectsForSignature(m_pDxilModule->GetPatchConstOrPrimSignature(), m_PatchConstantSignature);
     MarkUsedSignatureElements();
     return S_OK;
   }
@@ -1969,14 +1969,14 @@ void DxilShaderReflection::MarkUsedSignatureElements() {
       if (!GetUnsignedVal(SPC.get_col(), &col)) continue;
       if (!GetUnsignedVal(SPC.get_row(), &row)) continue;
       pDescs = &m_PatchConstantSignature;
-      pSig = &m_pDxilModule->GetPatchConstantSignature();
+      pSig = &m_pDxilModule->GetPatchConstOrPrimSignature();
     }
     else if (LPC) {
       if (!GetUnsignedVal(LPC.get_inputSigId(), &sigId)) continue;
       if (!GetUnsignedVal(LPC.get_col(), &col)) continue;
       if (!GetUnsignedVal(LPC.get_row(), &row)) continue;
       pDescs = &m_PatchConstantSignature;
-      pSig = &m_pDxilModule->GetPatchConstantSignature();
+      pSig = &m_pDxilModule->GetPatchConstOrPrimSignature();
     }
     else {
       continue;
@@ -2171,21 +2171,14 @@ UINT DxilShaderReflection::GetThreadGroupSize(UINT *pSizeX, UINT *pSizeY, UINT *
 }
 
 UINT64 DxilShaderReflection::GetRequiresFlags() {
-  UINT64 result = 0;
-  uint64_t features = m_pDxilModule->m_ShaderFlags.GetFeatureInfo();
-  if (features & ShaderFeatureInfo_Doubles) result |= D3D_SHADER_REQUIRES_DOUBLES;
-  if (features & ShaderFeatureInfo_UAVsAtEveryStage) result |= D3D_SHADER_REQUIRES_UAVS_AT_EVERY_STAGE;
-  if (features & ShaderFeatureInfo_64UAVs) result |= D3D_SHADER_REQUIRES_64_UAVS;
-  if (features & ShaderFeatureInfo_MinimumPrecision) result |= D3D_SHADER_REQUIRES_MINIMUM_PRECISION;
-  if (features & ShaderFeatureInfo_11_1_DoubleExtensions) result |= D3D_SHADER_REQUIRES_11_1_DOUBLE_EXTENSIONS;
-  if (features & ShaderFeatureInfo_11_1_ShaderExtensions) result |= D3D_SHADER_REQUIRES_11_1_SHADER_EXTENSIONS;
-  if (features & ShaderFeatureInfo_LEVEL9ComparisonFiltering) result |= D3D_SHADER_REQUIRES_LEVEL_9_COMPARISON_FILTERING;
-  if (features & ShaderFeatureInfo_TiledResources) result |= D3D_SHADER_REQUIRES_TILED_RESOURCES;
-  if (features & ShaderFeatureInfo_StencilRef) result |= D3D_SHADER_REQUIRES_STENCIL_REF;
-  if (features & ShaderFeatureInfo_InnerCoverage) result |= D3D_SHADER_REQUIRES_INNER_COVERAGE;
-  if (features & ShaderFeatureInfo_TypedUAVLoadAdditionalFormats) result |= D3D_SHADER_REQUIRES_TYPED_UAV_LOAD_ADDITIONAL_FORMATS;
-  if (features & ShaderFeatureInfo_ROVs) result |= D3D_SHADER_REQUIRES_ROVS;
-  if (features & ShaderFeatureInfo_ViewportAndRTArrayIndexFromAnyShaderFeedingRasterizer) result |= D3D_SHADER_REQUIRES_VIEWPORT_AND_RT_ARRAY_INDEX_FROM_ANY_SHADER_FEEDING_RASTERIZER;
+  UINT64 result = m_pDxilModule->m_ShaderFlags.GetFeatureInfo();
+  // FeatureInfo flags are identical, with the exception of a collision between:
+  // SHADER_FEATURE_COMPUTE_SHADERS_PLUS_RAW_AND_STRUCTURED_BUFFERS_VIA_SHADER_4_X
+  // and D3D_SHADER_REQUIRES_EARLY_DEPTH_STENCIL
+  // We keep track of the flag elsewhere, so use that instead.
+  result &= ~(UINT64)D3D_SHADER_REQUIRES_EARLY_DEPTH_STENCIL;
+  if (m_pDxilModule->m_ShaderFlags.GetForceEarlyDepthStencil())
+    result |= D3D_SHADER_REQUIRES_EARLY_DEPTH_STENCIL;
   return result;
 }
 

+ 17 - 6
lib/HLSL/DxilPreserveAllOutputs.cpp

@@ -31,7 +31,10 @@ public:
   explicit OutputWrite(CallInst *call)
     : m_Call(call)
   {
-    assert(DxilInst_StoreOutput(call) || DxilInst_StorePatchConstant(call));
+    assert(DxilInst_StoreOutput(call) ||
+           DxilInst_StoreVertexOutput(call) ||
+           DxilInst_StorePrimitiveOutput(call) ||
+           DxilInst_StorePatchConstant(call));
   }
 
   unsigned GetSignatureID() const {
@@ -40,8 +43,8 @@ public:
   }
 
   DxilSignatureElement &GetSignatureElement(DxilModule &DM) const {
-    if (DxilInst_StorePatchConstant(m_Call))
-      return DM.GetPatchConstantSignature().GetElement(GetSignatureID());
+    if (DxilInst_StorePatchConstant(m_Call) || DxilInst_StorePrimitiveOutput(m_Call))
+      return DM.GetPatchConstOrPrimSignature().GetElement(GetSignatureID());
     else
       return DM.GetOutputSignature().GetElement(GetSignatureID());
   }
@@ -169,8 +172,14 @@ private:
   }
 
   DXIL::OpCode GetOutputOpCode() const {
-    if (m_OutputElement.IsPatchConstant())
-      return DXIL::OpCode::StorePatchConstant;
+    if (m_OutputElement.IsPatchConstOrPrim()) {
+      if (m_OutputElement.GetSigPointKind() == DXIL::SigPointKind::PCOut)
+        return DXIL::OpCode::StorePatchConstant;
+      else {
+        assert(m_OutputElement.GetSigPointKind() == DXIL::SigPointKind::MSPOut);
+        return DXIL::OpCode::StorePrimitiveOutput;
+      }
+    }
     else
       return DXIL::OpCode::StoreOutput;
   }
@@ -228,9 +237,11 @@ DxilPreserveAllOutputs::OutputVec DxilPreserveAllOutputs::collectOutputStores(Fu
   for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
     Instruction *inst = &*I;
     DxilInst_StoreOutput storeOutput(inst);
+    DxilInst_StoreVertexOutput storeVertexOutput(inst);
+    DxilInst_StorePrimitiveOutput storePrimitiveOutput(inst);
     DxilInst_StorePatchConstant storePatch(inst);
 
-    if (storeOutput || storePatch)
+    if (storeOutput || storeVertexOutput || storePrimitiveOutput || storePatch)
       calls.emplace_back(cast<CallInst>(inst));
   }
   return calls;

+ 467 - 43
lib/HLSL/DxilValidation.cpp

@@ -191,6 +191,12 @@ const char *hlsl::GetValidationRuleText(ValidationRule value) {
     case hlsl::ValidationRule::InstrAttributeAtVertexNoInterpolation: return "Attribute %0 must have nointerpolation mode in order to use GetAttributeAtVertex function.";
     case hlsl::ValidationRule::InstrCreateHandleImmRangeID: return "Local resource must map to global resource.";
     case hlsl::ValidationRule::InstrSignatureOperationNotInEntry: return "Dxil operation for input output signature must be in entryPoints.";
+    case hlsl::ValidationRule::InstrMultipleSetMeshOutputCounts: return "SetMeshOUtputCounts cannot be called multiple times.";
+    case hlsl::ValidationRule::InstrMissingSetMeshOutputCounts: return "Missing SetMeshOutputCounts call.";
+    case hlsl::ValidationRule::InstrNonDominatingSetMeshOutputCounts: return "Non-Dominating SetMeshOutputCounts call.";
+    case hlsl::ValidationRule::InstrMultipleGetMeshPayload: return "GetMeshPayload cannot be called multiple times.";
+    case hlsl::ValidationRule::InstrNotOnceDispatchMesh: return "DispatchMesh must be called exactly once in an Amplification shader.";
+    case hlsl::ValidationRule::InstrNonDominatingDispatchMesh: return "Non-Dominating DispatchMesh call.";
     case hlsl::ValidationRule::TypesNoVector: return "Vector type '%0' is not allowed";
     case hlsl::ValidationRule::TypesDefined: return "Type '%0' is not defined on DXIL primitives";
     case hlsl::ValidationRule::TypesIntWidth: return "Int type '%0' has an invalid width";
@@ -202,6 +208,7 @@ const char *hlsl::GetValidationRuleText(ValidationRule value) {
     case hlsl::ValidationRule::SmOperand: return "Operand must be defined in target shader model";
     case hlsl::ValidationRule::SmSemantic: return "Semantic '%0' is invalid as %1 %2";
     case hlsl::ValidationRule::SmNoInterpMode: return "Interpolation mode for '%0' is set but should be undefined";
+    case hlsl::ValidationRule::SmConstantInterpMode: return "Interpolation mode for '%0' should be constant";
     case hlsl::ValidationRule::SmNoPSOutputIdx: return "Pixel shader output registers are not indexable.";
     case hlsl::ValidationRule::SmPSConsistentInterp: return "Interpolation mode for PS input position must be linear_noperspective_centroid or linear_noperspective_sample when outputting oDepthGE or oDepthLE and not running at sample frequency (which is forced by inputting SV_SampleIndex or declaring an input linear_sample or linear_noperspective_sample)";
     case hlsl::ValidationRule::SmThreadGroupChannelRange: return "Declared Thread Group %0 size %1 outside valid range [%2..%3]";
@@ -253,6 +260,16 @@ const char *hlsl::GetValidationRuleText(ValidationRule value) {
     case hlsl::ValidationRule::Sm64bitRawBufferLoadStore: return "i64/f64 rawBufferLoad/Store overloads are allowed after SM 6.3";
     case hlsl::ValidationRule::SmRayShaderSignatures: return "Ray tracing shader '%0' should not have any shader signatures";
     case hlsl::ValidationRule::SmRayShaderPayloadSize: return "For shader '%0', %1 size is smaller than argument's allocation size";
+    case hlsl::ValidationRule::SmMeshShaderMaxVertexCount: return "MS max vertex output count must be [0..%0].  %1 specified";
+    case hlsl::ValidationRule::SmMeshShaderMaxPrimitiveCount: return "MS max primitive output count must be [0..%0].  %1 specified";
+    case hlsl::ValidationRule::SmMeshShaderPayloadSize: return "For shader '%0', payload size is greater than %1";
+    case hlsl::ValidationRule::SmMeshShaderOutputSize: return "For shader '%0', vertex plus primitive output size is greater than %1";
+    case hlsl::ValidationRule::SmMeshShaderInOutSize: return "For shader '%0', input plus output size is greater than %1";
+    case hlsl::ValidationRule::SmMeshVSigRowCount: return "For shader '%0', vertex output signatures are taking up more than %1 rows";
+    case hlsl::ValidationRule::SmMeshPSigRowCount: return "For shader '%0', primitive output signatures are taking up more than %1 rows";
+    case hlsl::ValidationRule::SmMeshTotalSigRowCount: return "For shader '%0', vertex and primitive output signatures are taking up more than %1 rows";
+    case hlsl::ValidationRule::SmMaxMSSMSize: return "Total Thread Group Shared Memory storage is %0, exceeded %1";
+    case hlsl::ValidationRule::SmAmplificationShaderPayloadSize: return "For shader '%0', payload size is greater than %1";
     case hlsl::ValidationRule::UniNoWaveSensitiveGradient: return "Gradient operations are not affected by wave-sensitive data or control flow.";
     case hlsl::ValidationRule::FlowReducible: return "Execution flow must be reducible";
     case hlsl::ValidationRule::FlowNoRecusion: return "Recursion is not permitted";
@@ -356,7 +373,7 @@ struct EntryStatus {
   bool hasOutputPosition[DXIL::kNumOutputStreams];
   unsigned OutputPositionMask[DXIL::kNumOutputStreams];
   std::vector<unsigned> outputCols;
-  std::vector<unsigned> patchConstCols;
+  std::vector<unsigned> patchConstOrPrimCols;
   bool m_bCoverageIn, m_bInnerCoverageIn;
   bool hasViewID;
   unsigned domainLocSize;
@@ -368,8 +385,8 @@ struct EntryStatus {
     }
 
     outputCols.resize(entryProps.sig.OutputSignature.GetElements().size(), 0);
-    patchConstCols.resize(
-        entryProps.sig.PatchConstantSignature.GetElements().size(), 0);
+    patchConstOrPrimCols.resize(
+        entryProps.sig.PatchConstOrPrimSignature.GetElements().size(), 0);
   }
 };
 
@@ -781,7 +798,7 @@ static bool ValidateOpcodeInProfile(DXIL::OpCode opcode,
   // Instructions: ThreadId=93, GroupId=94, ThreadIdInGroup=95,
   // FlattenedThreadIdInGroup=96
   if ((93 <= op && op <= 96))
-    return (SK == DXIL::ShaderKind::Compute);
+    return (SK == DXIL::ShaderKind::Compute || SK == DXIL::ShaderKind::Mesh || SK == DXIL::ShaderKind::Amplification);
   // Instructions: DomainLocation=105
   if (op == 105)
     return (SK == DXIL::ShaderKind::Domain);
@@ -815,7 +832,7 @@ static bool ValidateOpcodeInProfile(DXIL::OpCode opcode,
   // Instructions: ViewID=138
   if (op == 138)
     return (major > 6 || (major == 6 && minor >= 1))
-        && (SK == DXIL::ShaderKind::Vertex || SK == DXIL::ShaderKind::Hull || SK == DXIL::ShaderKind::Domain || SK == DXIL::ShaderKind::Geometry || SK == DXIL::ShaderKind::Pixel);
+        && (SK == DXIL::ShaderKind::Vertex || SK == DXIL::ShaderKind::Hull || SK == DXIL::ShaderKind::Domain || SK == DXIL::ShaderKind::Geometry || SK == DXIL::ShaderKind::Pixel || SK == DXIL::ShaderKind::Mesh);
   // Instructions: RawBufferLoad=139, RawBufferStore=140
   if ((139 <= op && op <= 140))
     return (major > 6 || (major == 6 && minor >= 2));
@@ -857,9 +874,49 @@ static bool ValidateOpcodeInProfile(DXIL::OpCode opcode,
   if ((162 <= op && op <= 164))
     return (major > 6 || (major == 6 && minor >= 4));
   // Instructions: WaveMatch=165, WaveMultiPrefixOp=166,
-  // WaveMultiPrefixBitCount=167
-  if ((165 <= op && op <= 167))
+  // WaveMultiPrefixBitCount=167, WriteSamplerFeedbackLevel=176,
+  // WriteSamplerFeedbackGrad=177, AllocateRayQuery=178,
+  // RayQuery_TraceRayInline=179, RayQuery_Proceed=180, RayQuery_Abort=181,
+  // RayQuery_CommitNonOpaqueTriangleHit=182,
+  // RayQuery_CommitProceduralPrimitiveHit=183, RayQuery_CommittedStatus=184,
+  // RayQuery_CandidateType=185, RayQuery_CandidateObjectToWorld3x4=186,
+  // RayQuery_CandidateWorldToObject3x4=187,
+  // RayQuery_CommittedObjectToWorld3x4=188,
+  // RayQuery_CommittedWorldToObject3x4=189,
+  // RayQuery_CandidateProceduralPrimitiveNonOpaque=190,
+  // RayQuery_CandidateTriangleFrontFace=191,
+  // RayQuery_CommittedTriangleFrontFace=192,
+  // RayQuery_CandidateTriangleBarycentrics=193,
+  // RayQuery_CommittedTriangleBarycentrics=194, RayQuery_RayFlags=195,
+  // RayQuery_WorldRayOrigin=196, RayQuery_WorldRayDirection=197,
+  // RayQuery_RayTMin=198, RayQuery_CandidateTriangleRayT=199,
+  // RayQuery_CommittedRayT=200, RayQuery_CandidateInstanceIndex=201,
+  // RayQuery_CandidateInstanceID=202, RayQuery_CandidateGeometryIndex=203,
+  // RayQuery_CandidatePrimitiveIndex=204, RayQuery_CandidateObjectRayOrigin=205,
+  // RayQuery_CandidateObjectRayDirection=206,
+  // RayQuery_CommittedInstanceIndex=207, RayQuery_CommittedInstanceID=208,
+  // RayQuery_CommittedGeometryIndex=209, RayQuery_CommittedPrimitiveIndex=210,
+  // RayQuery_CommittedObjectRayOrigin=211,
+  // RayQuery_CommittedObjectRayDirection=212
+  if ((165 <= op && op <= 167) || (176 <= op && op <= 212))
     return (major > 6 || (major == 6 && minor >= 5));
+  // Instructions: DispatchMesh=173
+  if (op == 173)
+    return (major > 6 || (major == 6 && minor >= 5))
+        && (SK == DXIL::ShaderKind::Amplification);
+  // Instructions: GeometryIndex=213
+  if (op == 213)
+    return (major > 6 || (major == 6 && minor >= 5))
+        && (SK == DXIL::ShaderKind::Library || SK == DXIL::ShaderKind::Intersection || SK == DXIL::ShaderKind::AnyHit || SK == DXIL::ShaderKind::ClosestHit);
+  // Instructions: WriteSamplerFeedback=174, WriteSamplerFeedbackBias=175
+  if ((174 <= op && op <= 175))
+    return (major > 6 || (major == 6 && minor >= 5))
+        && (SK == DXIL::ShaderKind::Library || SK == DXIL::ShaderKind::Pixel);
+  // Instructions: SetMeshOutputCounts=168, EmitIndices=169, GetMeshPayload=170,
+  // StoreVertexOutput=171, StorePrimitiveOutput=172
+  if ((168 <= op && op <= 172))
+    return (major > 6 || (major == 6 && minor >= 5))
+        && (SK == DXIL::ShaderKind::Mesh);
   return true;
   // VALOPCODESM-TEXT:END
 }
@@ -889,8 +946,8 @@ static unsigned ValidateSignatureRowCol(Instruction *I,
   } else {
     if (SE.IsOutput())
       Status.outputCols[SE.GetID()] |= 1 << col;
-    if (SE.IsPatchConstant())
-      Status.patchConstCols[SE.GetID()] |= 1 << col;
+    if (SE.IsPatchConstOrPrim())
+      Status.patchConstOrPrimCols[SE.GetID()] |= 1 << col;
   }
 
   return col;
@@ -1463,10 +1520,13 @@ static void ValidateSignatureDxilOp(CallInst *CI, DXIL::OpCode opcode,
       }
     }
   } break;
-  case DXIL::OpCode::StoreOutput: {
+  case DXIL::OpCode::StoreOutput:
+  case DXIL::OpCode::StoreVertexOutput: 
+  case DXIL::OpCode::StorePrimitiveOutput: {
     Value *outputID =
         CI->getArgOperand(DXIL::OperandIndex::kStoreOutputIDOpIdx);
-    DxilSignature &outputSig = S.OutputSignature;
+    DxilSignature &outputSig = opcode == DXIL::OpCode::StorePrimitiveOutput ?
+      S.PatchConstOrPrimSignature : S.OutputSignature;
     Value *row = CI->getArgOperand(DXIL::OperandIndex::kStoreOutputRowOpIdx);
     Value *col = CI->getArgOperand(DXIL::OperandIndex::kStoreOutputColOpIdx);
     ValidateSignatureAccess(CI, outputSig, outputID, row, col, Status, ValCtx);
@@ -1509,7 +1569,7 @@ static void ValidateSignatureDxilOp(CallInst *CI, DXIL::OpCode opcode,
         DxilEntrySignature &S = EntryProps.sig;
         Value *outputID =
             CI->getArgOperand(DXIL::OperandIndex::kStoreOutputIDOpIdx);
-        DxilSignature &outputSig = S.PatchConstantSignature;
+        DxilSignature &outputSig = S.PatchConstOrPrimSignature;
         Value *row =
             CI->getArgOperand(DXIL::OperandIndex::kStoreOutputRowOpIdx);
         Value *col =
@@ -1615,6 +1675,30 @@ static void ValidateSignatureDxilOp(CallInst *CI, DXIL::OpCode opcode,
                                   {"Emit/CutStream", "Geometry shader"});
     }
   } break;
+  case DXIL::OpCode::EmitIndices: {
+    if (!props.IsMS()) {
+      ValCtx.EmitInstrFormatError(CI, ValidationRule::SmOpcodeInInvalidFunction,
+                                  {"EmitIndices", "Mesh shader"});
+    }
+  } break;
+  case DXIL::OpCode::SetMeshOutputCounts: {
+    if (!props.IsMS()) {
+      ValCtx.EmitInstrFormatError(CI, ValidationRule::SmOpcodeInInvalidFunction,
+                                  {"SetMeshOutputCounts", "Mesh shader"});
+    }
+  } break;
+  case DXIL::OpCode::GetMeshPayload: {
+    if (!props.IsMS()) {
+      ValCtx.EmitInstrFormatError(CI, ValidationRule::SmOpcodeInInvalidFunction,
+                                  {"GetMeshPayload", "Mesh shader"});
+    }
+  } break;
+  case DXIL::OpCode::DispatchMesh: {
+    if (!props.IsAS()) {
+      ValCtx.EmitInstrFormatError(CI, ValidationRule::SmOpcodeInInvalidFunction,
+                                  {"DispatchMesh", "Amplification shader"});
+    }
+  } break;
   default:
     break;
   }
@@ -2284,6 +2368,8 @@ static void ValidateDxilOperationCallInProfile(CallInst *CI,
   case DXIL::OpCode::LoadInput:
   case DXIL::OpCode::DomainLocation:
   case DXIL::OpCode::StoreOutput:
+  case DXIL::OpCode::StoreVertexOutput:
+  case DXIL::OpCode::StorePrimitiveOutput:
   case DXIL::OpCode::OutputControlPointID:
   case DXIL::OpCode::LoadOutputControlPoint:
   case DXIL::OpCode::StorePatchConstant:
@@ -2700,6 +2786,110 @@ static void ValidateGradientOps(Function *F, ArrayRef<CallInst *> ops, ArrayRef<
   }
 }
 
+// Validates mesh-shader intrinsic usage inside F.
+//
+// Does nothing unless F has DXIL function properties and its shader kind is
+// Mesh. For mesh shaders it checks that every StoreVertexOutput,
+// StorePrimitiveOutput and EmitIndices call is dominated by the
+// SetMeshOutputCounts call, and that the payload returned by GetMeshPayload
+// does not exceed DXIL::kMaxMSASPayloadSize.
+//
+// setMeshOutputCounts - the SetMeshOutputCounts call recorded by the caller
+//                       for F, or nullptr when F contains none.
+// getMeshPayload      - the GetMeshPayload call recorded by the caller for F,
+//                       or nullptr when F contains none.
+//
+// Side effect: when getMeshPayload is non-null, the computed payload size is
+// written to the function's ShaderProps.MS.payloadByteSize.
+static void ValidateMsIntrinsics(Function *F,
+                                 ValidationContext &ValCtx,
+                                 CallInst *setMeshOutputCounts,
+                                 CallInst *getMeshPayload) {
+  if (ValCtx.DxilMod.HasDxilFunctionProps(F)) {
+    DXIL::ShaderKind shaderKind = ValCtx.DxilMod.GetDxilFunctionProps(F).shaderKind;
+    if (shaderKind != DXIL::ShaderKind::Mesh)
+      return;
+  } else {
+    return;
+  }
+
+  DominatorTreeAnalysis DTA;
+  DominatorTree DT = DTA.run(*F);
+
+  for (auto b = F->begin(), bend = F->end(); b != bend; ++b) {
+    for (auto i = b->begin(), iend = b->end(); i != iend; ++i) {
+      llvm::Instruction &I = *i;
+
+      // Calls to external functions.
+      CallInst *CI = dyn_cast<CallInst>(&I);
+      if (!CI)
+        continue;
+
+      // getCalledFunction() returns null for an indirect call; the validator
+      // must not crash on such input, so skip it here.
+      Function *FCalled = CI->getCalledFunction();
+      if (!FCalled || !FCalled->isDeclaration())
+        continue;
+
+      // External function validation will diagnose.
+      if (!IsDxilFunction(FCalled))
+        continue;
+
+      // A malformed DXIL op call may carry a non-constant opcode operand.
+      // Skip it rather than dereferencing a null ConstantInt.
+      // NOTE(review): presumably reported by the general DXIL op checks.
+      ConstantInt *OpcodeConst = dyn_cast<ConstantInt>(CI->getOperand(0));
+      if (!OpcodeConst)
+        continue;
+      DXIL::OpCode dxilOpcode = (DXIL::OpCode)OpcodeConst->getLimitedValue();
+
+      if (dxilOpcode == DXIL::OpCode::StoreVertexOutput ||
+          dxilOpcode == DXIL::OpCode::StorePrimitiveOutput ||
+          dxilOpcode == DXIL::OpCode::EmitIndices) {
+        if (setMeshOutputCounts == nullptr) {
+          ValCtx.EmitInstrError(&I, ValidationRule::InstrMissingSetMeshOutputCounts);
+        } else if (!DT.dominates(setMeshOutputCounts, &I)) {
+          // Instruction-level dominance is required here: a basic block
+          // dominates itself, so a block-level query would accept an output
+          // write placed before SetMeshOutputCounts in the same block.
+          ValCtx.EmitInstrError(&I, ValidationRule::InstrNonDominatingSetMeshOutputCounts);
+        }
+      }
+    }
+  }
+
+  if (getMeshPayload) {
+    // GetMeshPayload returns a pointer to the payload struct; measure the
+    // pointee with the module's data layout.
+    PointerType *payloadPTy = cast<PointerType>(getMeshPayload->getType());
+    StructType *payloadTy = cast<StructType>(payloadPTy->getPointerElementType());
+    const DataLayout &DL = F->getParent()->getDataLayout();
+    unsigned payloadSize = DL.getTypeAllocSize(payloadTy);
+
+    if (payloadSize > DXIL::kMaxMSASPayloadSize) {
+      ValCtx.EmitFormatError(ValidationRule::SmMeshShaderPayloadSize,
+        { F->getName(), std::to_string(DXIL::kMaxMSASPayloadSize) });
+    }
+
+    DxilFunctionProps &prop = ValCtx.DxilMod.GetDxilFunctionProps(F);
+    prop.ShaderProps.MS.payloadByteSize = payloadSize;
+  }
+}
+
+// Validates DispatchMesh usage inside F.
+//
+// Does nothing unless F has DXIL function properties and its shader kind is
+// Amplification. Otherwise it reports an error when no DispatchMesh call was
+// recorded (dispatchMesh == nullptr), when the block holding the call does
+// not post-dominate the entry block, or when the payload struct named by the
+// call's fifth parameter type is larger than DXIL::kMaxMSASPayloadSize.
+static void ValidateAsIntrinsics(Function *F, ValidationContext &ValCtx, CallInst *dispatchMesh) {
+  if (!ValCtx.DxilMod.HasDxilFunctionProps(F))
+    return;
+  if (ValCtx.DxilMod.GetDxilFunctionProps(F).shaderKind != DXIL::ShaderKind::Amplification)
+    return;
+
+  if (!dispatchMesh) {
+    ValCtx.EmitError(ValidationRule::InstrNotOnceDispatchMesh);
+    return;
+  }
+
+  // The block containing the call must post-dominate the entry block.
+  PostDominatorTree postDomTree;
+  postDomTree.runOnFunction(*F);
+  BasicBlock *dispatchBlock = dispatchMesh->getParent();
+  BasicBlock *entryBlock = &F->getEntryBlock();
+  if (!postDomTree.dominates(dispatchBlock, entryBlock)) {
+    ValCtx.EmitInstrError(dispatchMesh, ValidationRule::InstrNonDominatingDispatchMesh);
+  }
+
+  // The payload is described by parameter index 4 of the callee's type:
+  // a pointer to the payload struct.
+  FunctionType *calleeTy = dispatchMesh->getCalledFunction()->getFunctionType();
+  PointerType *payloadPtrTy = cast<PointerType>(calleeTy->getParamType(4));
+  StructType *payloadStructTy = cast<StructType>(payloadPtrTy->getPointerElementType());
+  const DataLayout &layout = F->getParent()->getDataLayout();
+  unsigned payloadBytes = layout.getTypeAllocSize(payloadStructTy);
+
+  if (payloadBytes > DXIL::kMaxMSASPayloadSize) {
+    ValCtx.EmitFormatError(ValidationRule::SmAmplificationShaderPayloadSize,
+      { F->getName(), std::to_string(DXIL::kMaxMSASPayloadSize) });
+  }
+}
+
 static void ValidateControlFlowHint(BasicBlock &bb, ValidationContext &ValCtx) {
   // Validate controlflow hint.
   TerminatorInst *TI = bb.getTerminator();
@@ -2957,11 +3147,49 @@ static bool IsLLVMInstructionAllowedForLib(Instruction &I, ValidationContext &Va
   }
 }
 
+// Returns true when I is an alloca used by a call to a function whose name
+// starts with "dx.op.dispatchMesh", or a GetElementPtr / Store whose address
+// chain leads back to such an alloca.
+//
+// Only Alloca, GetElementPtr and Store instructions can yield true; every
+// other opcode, and any address that is not itself an instruction, yields
+// false.
+static bool IsFromMeshPayload(Instruction *I) {
+  switch (I->getOpcode()) {
+  case Instruction::Alloca:
+    // Candidate root: inspect its users below.
+    break;
+  case Instruction::GetElementPtr: {
+    // Follow the base pointer (operand 0).
+    Instruction *base = dyn_cast<Instruction>(I->getOperand(0));
+    return base != nullptr && IsFromMeshPayload(base);
+  }
+  case Instruction::Store: {
+    // Follow the address being stored to (operand 1).
+    Instruction *dest = dyn_cast<Instruction>(I->getOperand(1));
+    return dest != nullptr && IsFromMeshPayload(dest);
+  }
+  default:
+    return false;
+  }
+
+  for (auto user : I->users()) {
+    if (CallInst *CI = dyn_cast<CallInst>(user)) {
+      // getCalledFunction() is null for an indirect call; such a call cannot
+      // be a DispatchMesh intrinsic, so skip it instead of dereferencing.
+      Function *func = CI->getCalledFunction();
+      if (func == nullptr)
+        continue;
+      if (func->getName().startswith("dx.op.dispatchMesh"))
+        return true;
+    }
+  }
+  return false;
+}
+
 static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
   bool SupportsMinPrecision =
       ValCtx.DxilMod.GetGlobalFlags() & DXIL::kEnableMinPrecision;
   SmallVector<CallInst *, 16> gradientOps;
   SmallVector<CallInst *, 16> barriers;
+  CallInst *setMeshOutputCounts = nullptr;
+  CallInst *getMeshPayload = nullptr;
+  CallInst *dispatchMesh = nullptr;
   for (auto b = F->begin(), bend = F->end(); b != bend; ++b) {
     for (auto i = b->begin(), iend = b->end(); i != iend; ++i) {
       llvm::Instruction &I = *i;
@@ -3023,6 +3251,30 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
           // External function validation will check the parameter
           // list. This function will check that the call does not
           // violate any rules.
+
+          if (dxilOpcode == DXIL::OpCode::SetMeshOutputCounts) {
+            // validate the call count of SetMeshOutputCounts
+            if (setMeshOutputCounts != nullptr) {
+              ValCtx.EmitInstrError(&I, ValidationRule::InstrMultipleSetMeshOutputCounts);
+            }
+            setMeshOutputCounts = CI;
+          }
+
+          if (dxilOpcode == DXIL::OpCode::GetMeshPayload) {
+            // validate the call count of GetMeshPayload
+            if (getMeshPayload != nullptr) {
+              ValCtx.EmitInstrError(&I, ValidationRule::InstrMultipleGetMeshPayload);
+            }
+            getMeshPayload = CI;
+          }
+
+          if (dxilOpcode == DXIL::OpCode::DispatchMesh) {
+            // validate the call count of DispatchMesh
+            if (dispatchMesh != nullptr) {
+              ValCtx.EmitInstrError(&I, ValidationRule::InstrNotOnceDispatchMesh);
+            }
+            dispatchMesh = CI;
+          }
         }
         continue;
       }
@@ -3069,11 +3321,13 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
       unsigned opcode = I.getOpcode();
       switch (opcode) {
       case Instruction::Alloca: {
-        AllocaInst *AI = cast<AllocaInst>(&I);
-        // TODO: validate address space and alignment
-        Type *Ty = AI->getAllocatedType();
-        if (!ValidateType(Ty, ValCtx)) {
-          continue;
+        if (!IsFromMeshPayload(&I)) {
+          AllocaInst *AI = cast<AllocaInst>(&I);
+          // TODO: validate address space and alignment
+          Type *Ty = AI->getAllocatedType();
+          if (!ValidateType(Ty, ValCtx)) {
+            continue;
+          }
         }
       } break;
       case Instruction::ExtractValue: {
@@ -3096,16 +3350,20 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
         }
       } break;
       case Instruction::Store: {
-        StoreInst *SI = cast<StoreInst>(&I);
-        Type *Ty = SI->getValueOperand()->getType();
-        if (!ValidateType(Ty, ValCtx)) {
-          continue;
+        if (!IsFromMeshPayload(&I)) {
+          StoreInst *SI = cast<StoreInst>(&I);
+          Type *Ty = SI->getValueOperand()->getType();
+          if (!ValidateType(Ty, ValCtx)) {
+            continue;
+          }
         }
       } break;
       case Instruction::GetElementPtr: {
-        Type *Ty = I.getType()->getPointerElementType();
-        if (!ValidateType(Ty, ValCtx)) {
-          continue;
+        if (!IsFromMeshPayload(&I)) {
+          Type *Ty = I.getType()->getPointerElementType();
+          if (!ValidateType(Ty, ValCtx)) {
+            continue;
+          }
         }
         GetElementPtrInst *GEP = cast<GetElementPtrInst>(&I);
         bool allImmIndex = true;
@@ -3220,6 +3478,10 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
   if (!gradientOps.empty()) {
     ValidateGradientOps(F, gradientOps, barriers, ValCtx);
   }
+
+  ValidateMsIntrinsics(F, ValCtx, setMeshOutputCounts, getMeshPayload);
+
+  ValidateAsIntrinsics(F, ValCtx, dispatchMesh);
 }
 
 static void ValidateFunction(Function &F, ValidationContext &ValCtx) {
@@ -3458,10 +3720,16 @@ static void ValidateGlobalVariables(ValidationContext &ValCtx) {
     }
   }
 
-  if (TGSMSize > DXIL::kMaxTGSMSize) {
+  if (M.GetShaderModel()->IsMS()) {
+    if (TGSMSize > DXIL::kMaxMSSMSize) {
+      ValCtx.EmitFormatError(ValidationRule::SmMaxMSSMSize,
+                             { std::to_string(TGSMSize),
+                               std::to_string(DXIL::kMaxMSSMSize) });
+    }
+  } else if (TGSMSize > DXIL::kMaxTGSMSize) {
     ValCtx.EmitFormatError(ValidationRule::SmMaxTGSMSize,
-                           {std::to_string(TGSMSize),
-                            std::to_string(DXIL::kMaxTGSMSize)});
+                           { std::to_string(TGSMSize),
+                             std::to_string(DXIL::kMaxTGSMSize) });
   }
   if (!fixAddrTGSMList.empty()) {
     ValidateTGSMRaceCondition(fixAddrTGSMList, ValCtx);
@@ -3643,6 +3911,11 @@ static void ValidateResource(hlsl::DxilResource &res,
   case DXIL::ResourceKind::RTAccelerationStructure:
     // TODO: check profile.
     break;
+  case DXIL::ResourceKind::FeedbackTexture2DMinLOD:
+  case DXIL::ResourceKind::FeedbackTexture2DTiled:
+  case DXIL::ResourceKind::FeedbackTexture2DArrayMinLOD:
+  case DXIL::ResourceKind::FeedbackTexture2DArrayTiled:
+    break;
   default:
     ValCtx.EmitResourceError(&res, ValidationRule::SmInvalidResourceKind);
     break;
@@ -4408,6 +4681,14 @@ static void ValidateNoInterpModeSignature(ValidationContext &ValCtx, const DxilS
   }
 }
 
+// Emits ValidationRule::SmConstantInterpMode for each element of signature S
+// whose interpolation mode is not constant.
+static void ValidateConstantInterpModeSignature(ValidationContext &ValCtx, const DxilSignature &S) {
+  for (auto &sigElement : S.GetElements()) {
+    const bool hasConstantInterp = sigElement->GetInterpolationMode()->IsConstant();
+    if (hasConstantInterp)
+      continue;
+    ValCtx.EmitSignatureError(sigElement.get(), ValidationRule::SmConstantInterpMode);
+  }
+}
+
 static void ValidateEntrySignatures(ValidationContext &ValCtx,
                                     const DxilEntryProps &entryProps,
                                     EntryStatus &Status,
@@ -4419,7 +4700,7 @@ static void ValidateEntrySignatures(ValidationContext &ValCtx,
     // No signatures allowed
     if (!S.InputSignature.GetElements().empty() ||
         !S.OutputSignature.GetElements().empty() ||
-        !S.PatchConstantSignature.GetElements().empty()) {
+        !S.PatchConstOrPrimSignature.GetElements().empty()) {
       ValCtx.EmitFormatError(ValidationRule::SmRayShaderSignatures, { F.getName() });
     }
 
@@ -4465,6 +4746,7 @@ static void ValidateEntrySignatures(ValidationContext &ValCtx,
   bool isVS = props.IsVS();
   bool isGS = props.IsGS();
   bool isCS = props.IsCS();
+  bool isMS = props.IsMS();
 
   if (isPS) {
     // PS output no interp mode.
@@ -4473,8 +4755,14 @@ static void ValidateEntrySignatures(ValidationContext &ValCtx,
     // VS input no interp mode.
     ValidateNoInterpModeSignature(ValCtx, S.InputSignature);
   }
-  // patch constant no interp mode.
-  ValidateNoInterpModeSignature(ValCtx, S.PatchConstantSignature);
+
+  if (isMS) {
+    // primitive output constant interp mode.
+    ValidateConstantInterpModeSignature(ValCtx, S.PatchConstOrPrimSignature);
+  } else {
+    // patch constant no interp mode.
+    ValidateNoInterpModeSignature(ValCtx, S.PatchConstOrPrimSignature);
+  }
 
   unsigned maxInputScalars = DXIL::kMaxInputTotalScalars;
   unsigned maxOutputScalars = 0;
@@ -4493,13 +4781,18 @@ static void ValidateEntrySignatures(ValidationContext &ValCtx,
       maxOutputScalars = DXIL::kMaxOutputTotalScalars;
       maxPatchConstantScalars = DXIL::kMaxHSOutputPatchConstantTotalScalars;
     break;
+  case DXIL::ShaderKind::Mesh:
+    maxOutputScalars = DXIL::kMaxOutputTotalScalars;
+    maxPatchConstantScalars = DXIL::kMaxOutputTotalScalars;
+    break;
+  case DXIL::ShaderKind::Amplification:
   default:
     break;
   }
 
   ValidateSignature(ValCtx, S.InputSignature, Status, maxInputScalars);
   ValidateSignature(ValCtx, S.OutputSignature, Status, maxOutputScalars);
-  ValidateSignature(ValCtx, S.PatchConstantSignature, Status,
+  ValidateSignature(ValCtx, S.PatchConstOrPrimSignature, Status,
                     maxPatchConstantScalars);
 
   if (isPS) {
@@ -4579,10 +4872,53 @@ static void ValidateEntrySignatures(ValidationContext &ValCtx,
   if (isCS) {
       if (!S.InputSignature.GetElements().empty() ||
           !S.OutputSignature.GetElements().empty() ||
-          !S.PatchConstantSignature.GetElements().empty()) {
+          !S.PatchConstOrPrimSignature.GetElements().empty()) {
         ValCtx.EmitError(ValidationRule::SmCSNoSignatures);
       }
   }
+
+  if (isMS) {
+    unsigned VertexSignatureRows = S.OutputSignature.GetRowCount();
+    if (VertexSignatureRows > DXIL::kMaxMSVSigRows) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshVSigRowCount,
+        { F.getName(), std::to_string(DXIL::kMaxMSVSigRows) });
+    }
+    unsigned PrimitiveSignatureRows = S.PatchConstOrPrimSignature.GetRowCount();
+    if (PrimitiveSignatureRows > DXIL::kMaxMSPSigRows) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshPSigRowCount,
+        { F.getName(), std::to_string(DXIL::kMaxMSPSigRows) });
+    }
+    if (VertexSignatureRows + PrimitiveSignatureRows > DXIL::kMaxMSTotalSigRows) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshTotalSigRowCount,
+        { F.getName(), std::to_string(DXIL::kMaxMSTotalSigRows) });
+    }
+
+    unsigned maxVertexCount = props.ShaderProps.MS.maxVertexCount;
+    unsigned maxPrimitiveCount = props.ShaderProps.MS.maxPrimitiveCount;
+    unsigned totalOutputScalars = 0;
+    for (auto &SE : S.OutputSignature.GetElements()) {
+      totalOutputScalars += SE->GetRows() * SE->GetCols() * maxVertexCount;
+    }
+    for (auto &SE : S.PatchConstOrPrimSignature.GetElements()) {
+      totalOutputScalars += SE->GetRows() * SE->GetCols() * maxPrimitiveCount;
+    }
+
+    if (totalOutputScalars > DXIL::kMaxMSOutputTotalScalars) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshShaderOutputSize,
+        { F.getName(), std::to_string(DXIL::kMaxMSOutputTotalScalars) });
+    }
+
+    unsigned totalInputOutputScalars = totalOutputScalars + props.ShaderProps.MS.payloadByteSize;
+    if (totalInputOutputScalars > DXIL::kMaxMSInputOutputTotalScalars) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshShaderInOutSize,
+        { F.getName(), std::to_string(DXIL::kMaxMSInputOutputTotalScalars) });
+    }
+  }
 }
 
 static void ValidateEntrySignatures(ValidationContext &ValCtx) {
@@ -4617,7 +4953,7 @@ static void CheckPatchConstantSemantic(ValidationContext &ValCtx,
   DXIL::TessellatorDomain domain =
       isHS ? props.ShaderProps.HS.domain : props.ShaderProps.DS.domain;
 
-  const DxilSignature &patchConstantSig = EntryProps.sig.PatchConstantSignature;
+  const DxilSignature &patchConstantSig = EntryProps.sig.PatchConstOrPrimSignature;
 
   const unsigned kQuadEdgeSize = 4;
   const unsigned kQuadInsideSize = 2;
@@ -4765,6 +5101,92 @@ static void ValidateEntryProps(ValidationContext &ValCtx,
                               std::to_string(DXIL::kMaxCSThreadsPerGroup)});
     }
 
+    // type of threadID, thread group ID take care by DXIL operation overload
+    // check.
+  } else if (ShaderType == DXIL::ShaderKind::Mesh) {
+    const auto &MS = props.ShaderProps.MS;
+    unsigned x = MS.numThreads[0];
+    unsigned y = MS.numThreads[1];
+    unsigned z = MS.numThreads[2];
+
+    unsigned threadsInGroup = x * y * z;
+
+    if ((x < DXIL::kMinMSASThreadGroupX) || (x > DXIL::kMaxMSASThreadGroupX)) {
+      ValCtx.EmitFormatError(ValidationRule::SmThreadGroupChannelRange,
+                             {"X", std::to_string(x),
+                              std::to_string(DXIL::kMinMSASThreadGroupX),
+                              std::to_string(DXIL::kMaxMSASThreadGroupX)});
+    }
+    if ((y < DXIL::kMinMSASThreadGroupY) || (y > DXIL::kMaxMSASThreadGroupY)) {
+      ValCtx.EmitFormatError(ValidationRule::SmThreadGroupChannelRange,
+                             {"Y", std::to_string(y),
+                              std::to_string(DXIL::kMinMSASThreadGroupY),
+                              std::to_string(DXIL::kMaxMSASThreadGroupY)});
+    }
+    if ((z < DXIL::kMinMSASThreadGroupZ) || (z > DXIL::kMaxMSASThreadGroupZ)) {
+      ValCtx.EmitFormatError(ValidationRule::SmThreadGroupChannelRange,
+                             {"Z", std::to_string(z),
+                              std::to_string(DXIL::kMinMSASThreadGroupZ),
+                              std::to_string(DXIL::kMaxMSASThreadGroupZ)});
+    }
+
+    if (threadsInGroup > DXIL::kMaxMSASThreadsPerGroup) {
+      ValCtx.EmitFormatError(ValidationRule::SmMaxTheadGroup,
+                             {std::to_string(threadsInGroup),
+                              std::to_string(DXIL::kMaxMSASThreadsPerGroup)});
+    }
+
+    // type of threadID, thread group ID take care by DXIL operation overload
+    // check.
+
+    unsigned maxVertexCount = MS.maxVertexCount;
+    if (maxVertexCount > DXIL::kMaxMSOutputVertexCount) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshShaderMaxVertexCount,
+          { std::to_string(DXIL::kMaxMSOutputVertexCount),
+            std::to_string(maxVertexCount) });
+    }
+
+    unsigned maxPrimitiveCount = MS.maxPrimitiveCount;
+    if (maxPrimitiveCount > DXIL::kMaxMSOutputPrimitiveCount) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshShaderMaxPrimitiveCount,
+          { std::to_string(DXIL::kMaxMSOutputPrimitiveCount),
+            std::to_string(maxPrimitiveCount) });
+    }
+  } else if (ShaderType == DXIL::ShaderKind::Amplification) {
+    const auto &AS = props.ShaderProps.AS;
+    unsigned x = AS.numThreads[0];
+    unsigned y = AS.numThreads[1];
+    unsigned z = AS.numThreads[2];
+
+    unsigned threadsInGroup = x * y * z;
+
+    if ((x < DXIL::kMinMSASThreadGroupX) || (x > DXIL::kMaxMSASThreadGroupX)) {
+      ValCtx.EmitFormatError(ValidationRule::SmThreadGroupChannelRange,
+                             {"X", std::to_string(x),
+                              std::to_string(DXIL::kMinMSASThreadGroupX),
+                              std::to_string(DXIL::kMaxMSASThreadGroupX)});
+    }
+    if ((y < DXIL::kMinMSASThreadGroupY) || (y > DXIL::kMaxMSASThreadGroupY)) {
+      ValCtx.EmitFormatError(ValidationRule::SmThreadGroupChannelRange,
+                             {"Y", std::to_string(y),
+                              std::to_string(DXIL::kMinMSASThreadGroupY),
+                              std::to_string(DXIL::kMaxMSASThreadGroupY)});
+    }
+    if ((z < DXIL::kMinMSASThreadGroupZ) || (z > DXIL::kMaxMSASThreadGroupZ)) {
+      ValCtx.EmitFormatError(ValidationRule::SmThreadGroupChannelRange,
+                             {"Z", std::to_string(z),
+                              std::to_string(DXIL::kMinMSASThreadGroupZ),
+                              std::to_string(DXIL::kMaxMSASThreadGroupZ)});
+    }
+
+    if (threadsInGroup > DXIL::kMaxMSASThreadsPerGroup) {
+      ValCtx.EmitFormatError(ValidationRule::SmMaxTheadGroup,
+                             {std::to_string(threadsInGroup),
+                              std::to_string(DXIL::kMaxMSASThreadsPerGroup)});
+    }
+
     // type of threadID, thread group ID take care by DXIL operation overload
     // check.
   } else if (ShaderType == DXIL::ShaderKind::Domain) {
@@ -5023,10 +5445,10 @@ static void ValidateUninitializedOutput(ValidationContext &ValCtx,
   const DxilFunctionProps &props = entryProps.props;
   // For HS only need to check Tessfactor which is in patch constant sig.
   if (props.IsHS()) {
-    std::vector<unsigned> &patchConstCols = Status.patchConstCols;
-    const DxilSignature &patchConstSig = entryProps.sig.PatchConstantSignature;
+    std::vector<unsigned> &patchConstOrPrimCols = Status.patchConstOrPrimCols;
+    const DxilSignature &patchConstSig = entryProps.sig.PatchConstOrPrimSignature;
     for (auto &E : patchConstSig.GetElements()) {
-      unsigned mask = patchConstCols[E->GetID()];
+      unsigned mask = patchConstOrPrimCols[E->GetID()];
       unsigned requireMask = (1 << E->GetCols()) - 1;
       // TODO: check other case uninitialized output is allowed.
       if (mask != requireMask && !E->GetSemantic()->IsArbitrary()) {
@@ -5214,8 +5636,8 @@ static void VerifySignatureMatches(_In_ ValidationContext &ValCtx,
   case hlsl::DXIL::SignatureKind::Output:
     pName = "Program Output Signature";
     break;
-  case hlsl::DXIL::SignatureKind::PatchConstant:
-    pName = "Program Patch Constant Signature";
+  case hlsl::DXIL::SignatureKind::PatchConstOrPrim:
+    pName = "Program Patch Constant or Primitive Signature";
     break;
   default:
     break;
@@ -5360,7 +5782,9 @@ HRESULT ValidateDxilContainerParts(llvm::Module *pModule,
   ValidationContext ValCtx(*pModule, pDebugModule, *pDxilModule, DiagPrinter);
 
   DXIL::ShaderKind ShaderKind = pDxilModule->GetShaderModel()->GetKind();
-  bool bTess = ShaderKind == DXIL::ShaderKind::Hull || ShaderKind == DXIL::ShaderKind::Domain;
+  bool bTessOrMesh = ShaderKind == DXIL::ShaderKind::Hull ||
+                     ShaderKind == DXIL::ShaderKind::Domain ||
+                     ShaderKind == DXIL::ShaderKind::Mesh;
 
   std::unordered_set<uint32_t> FourCCFound;
   const DxilPartHeader *pRootSignaturePart = nullptr;
@@ -5398,8 +5822,8 @@ HRESULT ValidateDxilContainerParts(llvm::Module *pModule,
       if (ValCtx.isLibProfile) {
         ValCtx.EmitFormatError(ValidationRule::ContainerPartInvalid, { szFourCC });
       } else {
-        if (bTess) {
-          VerifySignatureMatches(ValCtx, DXIL::SignatureKind::PatchConstant, GetDxilPartData(pPart), pPart->PartSize);
+        if (bTessOrMesh) {
+          VerifySignatureMatches(ValCtx, DXIL::SignatureKind::PatchConstOrPrim, GetDxilPartData(pPart), pPart->PartSize);
         } else {
           ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, {"Program Patch Constant Signature"});
         }
@@ -5466,8 +5890,8 @@ HRESULT ValidateDxilContainerParts(llvm::Module *pModule,
     if (FourCCFound.find(DFCC_OutputSignature) == FourCCFound.end()) {
       VerifySignatureMatches(ValCtx, DXIL::SignatureKind::Output, nullptr, 0);
     }
-    if (bTess && FourCCFound.find(DFCC_PatchConstantSignature) == FourCCFound.end() &&
-        pDxilModule->GetPatchConstantSignature().GetElements().size())
+    if (bTessOrMesh && FourCCFound.find(DFCC_PatchConstantSignature) == FourCCFound.end() &&
+        pDxilModule->GetPatchConstOrPrimSignature().GetElements().size())
     {
       ValCtx.EmitFormatError(ValidationRule::ContainerPartMissing, { "Program Patch Constant Signature" });
     }

+ 2 - 0
lib/HLSL/HLModule.cpp

@@ -831,6 +831,8 @@ void HLModule::GetParameterRowsAndCols(Type *Ty, unsigned &rows, unsigned &cols,
   bool skipOneLevelArray = inputQual == DxilParamInputQual::InputPatch;
   skipOneLevelArray |= inputQual == DxilParamInputQual::OutputPatch;
   skipOneLevelArray |= inputQual == DxilParamInputQual::InputPrimitive;
+  skipOneLevelArray |= inputQual == DxilParamInputQual::OutVertices;
+  skipOneLevelArray |= inputQual == DxilParamInputQual::OutPrimitives;
 
   if (skipOneLevelArray) {
     if (Ty->isArrayTy())

+ 348 - 2
lib/HLSL/HLOperationLower.cpp

@@ -2368,6 +2368,36 @@ Value *TranslateFaceforward(CallInst *CI, IntrinsicOp IOP, OP::OpCode op,
   Value *faceforward = Builder.CreateSelect(dotLtZero, n, negN);
   return faceforward;
 }
+
+// Lowers the SetMeshOutputCounts intrinsic to a DXIL call. The call receives
+// the opcode constant followed by the intrinsic's two source operands, and the
+// op function is looked up with the void type. Always returns nullptr.
+Value *TrivialSetMeshOutputCounts(CallInst *CI, IntrinsicOp IOP, OP::OpCode op,
+  HLOperationLowerHelper &helper, HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
+  hlsl::OP &dxilOps = helper.hlslOP;
+  Value *callArgs[] = {
+    dxilOps.GetU32Const((unsigned)op),
+    CI->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
+    CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
+  };
+  Function *dxilFunc = dxilOps.GetOpFunc(op, Type::getVoidTy(CI->getContext()));
+
+  // The new call is emitted immediately before the original HL call.
+  IRBuilder<> Builder(CI);
+  Builder.CreateCall(dxilFunc, callArgs);
+  return nullptr;
+}
+
+// Lowers the DispatchMesh intrinsic to a DXIL call. The call receives the
+// opcode constant, the ThreadX/ThreadY/ThreadZ operands and the payload
+// operand; the op function is looked up using the payload operand's type.
+// Always returns nullptr.
+Value *TrivialDispatchMesh(CallInst *CI, IntrinsicOp IOP, OP::OpCode op,
+  HLOperationLowerHelper &helper, HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
+  hlsl::OP &dxilOps = helper.hlslOP;
+  Value *payload = CI->getArgOperand(HLOperandIndex::kDispatchMeshOpPayload);
+  Value *callArgs[] = {
+    dxilOps.GetU32Const((unsigned)op),
+    CI->getArgOperand(HLOperandIndex::kDispatchMeshOpThreadX),
+    CI->getArgOperand(HLOperandIndex::kDispatchMeshOpThreadY),
+    CI->getArgOperand(HLOperandIndex::kDispatchMeshOpThreadZ),
+    payload,
+  };
+  Function *dxilFunc = dxilOps.GetOpFunc(op, payload->getType());
+
+  // The new call is emitted immediately before the original HL call.
+  IRBuilder<> Builder(CI);
+  Builder.CreateCall(dxilFunc, callArgs);
+  return nullptr;
+}
 }
 
 // MOP intrinsics
@@ -3214,6 +3244,67 @@ Value *TranslateGather(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
   return nullptr;
 }
 
+// Lowers the WriteSamplerFeedback, WriteSamplerFeedbackBias,
+// WriteSamplerFeedbackLevel and WriteSamplerFeedbackGrad method intrinsics to
+// the DXIL op selected by `opcode`.
+//
+// DXIL operands are appended in this order:
+//   opcode, handle operand, sampled operand, sampler operand,
+//   three coordinate scalars (undef-padded),
+//   bias-or-lod (Bias/Level) or ddx + ddy (Grad),
+//   clamp (plain/Bias/Grad only; undef when the HL call does not supply one).
+//
+// Returns the created DXIL call instruction, emitted before CI.
+static Value* TranslateWriteSamplerFeedback(CallInst* CI, IntrinsicOp IOP, OP::OpCode opcode,
+  HLOperationLowerHelper& helper, HLObjectOperationLowerHelper* pObjHelper, bool& Translated) {
+
+  hlsl::OP* hlslOP = &helper.hlslOP;
+
+  IRBuilder<> Builder(CI);
+
+  // Build the DXIL operands: opcode first, then the fixed leading HL operands.
+  SmallVector<Value*, 8> DxilOperands;
+  DxilOperands.emplace_back(Builder.getInt32((unsigned)opcode));
+  DxilOperands.emplace_back(CI->getArgOperand(HLOperandIndex::kHandleOpIdx));
+  DxilOperands.emplace_back(CI->getArgOperand(HLOperandIndex::kWriteSamplerFeedbackSampledArgIndex));
+  DxilOperands.emplace_back(CI->getArgOperand(HLOperandIndex::kWriteSamplerFeedbackSamplerArgIndex));
+
+  // Coords operands: the HL coordinate vector is split into exactly
+  // NumDxilCoordsOperands scalars; components the vector does not have are
+  // filled with undef of the vector's element type.
+  constexpr unsigned NumDxilCoordsOperands = 3;
+  Value* CoordsVec = CI->getArgOperand(HLOperandIndex::kWriteSamplerFeedbackCoordArgIndex);
+  VectorType *CoordsVecTy = cast<VectorType>(CoordsVec->getType());
+  DXASSERT_NOMSG(CoordsVecTy->getNumElements() <= NumDxilCoordsOperands);
+  for (unsigned i = 0; i < NumDxilCoordsOperands; ++i) {
+    Value *Coord = i < CoordsVecTy->getNumElements()
+      ? Builder.CreateExtractElement(CoordsVec, i)
+      : UndefValue::get(CoordsVecTy->getElementType());
+    DxilOperands.emplace_back(Coord);
+  }
+
+  // Tracks the highest HL operand index consumed so far, so the optional
+  // clamp (if present) can be located right after it.
+  unsigned LastHLOperandRead = HLOperandIndex::kWriteSamplerFeedbackCoordArgIndex;
+
+  // Bias/level/grad operands
+  if (opcode == OP::OpCode::WriteSamplerFeedbackBias
+    || opcode == OP::OpCode::WriteSamplerFeedbackLevel) {
+    DxilOperands.emplace_back(CI->getArgOperand(HLOperandIndex::kWriteSamplerFeedbackBiasOrLodArgIndex));
+    LastHLOperandRead = HLOperandIndex::kWriteSamplerFeedbackBiasOrLodArgIndex;
+  }
+  else if (opcode == OP::OpCode::WriteSamplerFeedbackGrad) {
+    DxilOperands.emplace_back(CI->getArgOperand(HLOperandIndex::kWriteSamplerFeedbackDdxArgIndex));
+    DxilOperands.emplace_back(CI->getArgOperand(HLOperandIndex::kWriteSamplerFeedbackDdyArgIndex));
+    LastHLOperandRead = HLOperandIndex::kWriteSamplerFeedbackDdyArgIndex;
+  }
+
+  // Append the optional clamp argument as needed. The Level variant takes no
+  // clamp operand at all.
+  bool HasOptionalClampOperand = opcode == OP::OpCode::WriteSamplerFeedback
+    || opcode == OP::OpCode::WriteSamplerFeedbackBias
+    || opcode == OP::OpCode::WriteSamplerFeedbackGrad;
+  if (HasOptionalClampOperand) {
+    // No HL operand left: the clamp was omitted, so pass an undef float.
+    if (LastHLOperandRead == CI->getNumArgOperands() - 1)
+      DxilOperands.emplace_back(UndefValue::get(Builder.getFloatTy()));
+    else {
+      LastHLOperandRead++;
+      DxilOperands.emplace_back(CI->getArgOperand(LastHLOperandRead));
+    }
+  }
+
+  // Every HL operand must have been consumed by this point.
+  DXASSERT(LastHLOperandRead == CI->getNumArgOperands() - 1,
+    "Unexpected trailing hlsl intrinsic arguments.");
+
+  // Call the DXIL operation
+  Function* DxilFunc = hlslOP->GetOpFunc(opcode, Builder.getVoidTy());
+  return Builder.CreateCall(DxilFunc, DxilOperands);
+}
+
 // Load/Store intrinsics.
 struct ResLoadHelper {
   ResLoadHelper(CallInst *CI, DxilResource::Kind RK, DxilResourceBase::Class RC,
@@ -4639,6 +4730,206 @@ Value *TranslateTraceRay(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
   return Builder.CreateCall(F, Args);
 }
 
+// Inserts an AllocateRayQuery DXIL call for every alloca of RayQuery type.
+//
+// Every function in M that has a body, is not an intrinsic and is not an HL
+// operation is visited. Only its entry block is scanned. For each matching
+// alloca the call is emitted right after the alloca, taking the opcode and
+// the integral value of the type's single template argument, and its result
+// is stored through a {0, 0} GEP into the alloca.
+//
+// Arrays of RayQuery are detected but only asserted against; they are not
+// handled.
+void AllocateRayQueryObjects(llvm::Module *M, HLOperationLowerHelper &helper) {
+  hlsl::OP &hlslOP = helper.hlslOP;
+  Constant *i32Zero = hlslOP.GetI32Const(0);
+  const DXIL::OpCode opcode = DXIL::OpCode::AllocateRayQuery;
+  llvm::Value *opcodeVal = hlslOP.GetU32Const(static_cast<unsigned>(opcode));
+  for (Function &f : M->functions()) {
+    if (f.isDeclaration() || f.isIntrinsic() ||
+      GetHLOpcodeGroup(&f) != HLOpcodeGroup::NotHL)
+      continue;
+    // Iterate allocas
+    BasicBlock &BB = f.getEntryBlock();
+    IRBuilder<> Builder(dxilutil::FirstNonAllocaInsertionPt(&BB));
+    for (BasicBlock::iterator BI = BB.begin(), BE = BB.end(); BI != BE;) {
+      // Advance before inserting: new instructions are placed after I, and
+      // the iterator must already point past them.
+      Instruction *I = BI++;
+      AllocaInst *AI = dyn_cast<AllocaInst>(I);
+      if (!AI)
+        continue;
+
+      // Strip array levels to find the underlying element type.
+      llvm::Type *allocaTy = AI->getAllocatedType();
+      llvm::Type *elementTy = allocaTy;
+      while (elementTy->isArrayTy())
+        elementTy = elementTy->getArrayElementType();
+      if (!dxilutil::IsHLSLRayQueryType(elementTy))
+        continue;
+
+      DxilStructAnnotation *SA = helper.dxilTypeSys.GetStructAnnotation(cast<StructType>(elementTy));
+      DXASSERT(SA, "otherwise, could not find type annotation for RayQuery specialization");
+      DXASSERT(SA->GetNumTemplateArgs() == 1 && SA->GetTemplateArgAnnotation(0).IsIntegral(),
+               "otherwise, RayQuery has changed, or lacks template args");
+      Builder.SetInsertPoint(AI->getNextNode());
+      DXASSERT(!allocaTy->isArrayTy(), "Array not handled yet");
+      // Look the op up through the same `opcode` used for opcodeVal so the
+      // two cannot drift apart.
+      llvm::Function *AllocFn = hlslOP.GetOpFunc(opcode, Builder.getVoidTy());
+      llvm::Value *rayFlags = ConstantInt::get(helper.i32Ty,
+        APInt(32, SA->GetTemplateArgAnnotation(0).GetIntegral()));
+      llvm::CallInst *CI = Builder.CreateCall(AllocFn, {opcodeVal, rayFlags}, "hRayQuery");
+      llvm::Value *GEP = Builder.CreateGEP(AI, {i32Zero, i32Zero});
+      Builder.CreateStore(CI, GEP);
+    }
+  }
+}
+
+// Loads the value stored at index {0, 0} of the object passed as call operand
+// 1 of CI. The GEP and the load are emitted immediately before CI, and the
+// loaded value is returned.
+static Value* TranslateThisPointerToi32Handle(CallInst*CI, hlsl::OP *hlslOP)
+{
+  IRBuilder<> Builder(CI);
+  Value *object = CI->getArgOperand(1);
+  Constant *zero = hlslOP->GetI32Const(0);
+  Value *handleAddr = Builder.CreateGEP(object, {zero, zero});
+  return Builder.CreateLoad(handleAddr);
+}
+
+// Lowers the TraceRayInline method intrinsic to the DXIL op `opcode`.
+//
+// The DXIL argument array is filled as:
+//   [0]  opcode constant
+//   [1]  value loaded from the object operand (see
+//        TranslateThisPointerToi32Handle)
+//   [2 .. kTraceRayInlineRayDescOpIdx-1]  HL operands copied unchanged
+//   then the RayDesc operand flattened into 8 scalars:
+//        Origin.xyz, TMin, Direction.xyz, TMax.
+//
+// Returns the created DXIL call, emitted before CI.
+Value *TranslateTraceRayInline(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
+                         HLOperationLowerHelper &helper,
+                         HLObjectOperationLowerHelper *pObjHelper,
+                         bool &Translated) {
+  hlsl::OP *hlslOP = &helper.hlslOP;
+
+  // Pointer to the RayDesc aggregate; its fields are loaded through GEPs below.
+  Value *rayDesc = CI->getArgOperand(HLOperandIndex::kTraceRayInlineRayDescOpIdx);
+
+  Value *opArg = hlslOP->GetU32Const(static_cast<unsigned>(opcode));
+
+  Value *Args[DXIL::OperandIndex::kTraceRayInlineNumOp];
+  Args[0] = opArg;
+
+  // Translate this pointer to i32 handle value
+  Args[1] = TranslateThisPointerToi32Handle(CI, hlslOP);
+
+  // Operands between the object and the RayDesc are forwarded as-is, at the
+  // same index in the DXIL call as in the HL call.
+  for (unsigned i = 2; i < HLOperandIndex::kTraceRayInlineRayDescOpIdx; i++) {
+    Args[i] = CI->getArgOperand(i);
+  }
+  // struct RayDesc
+  //{
+  //    float3 Origin;
+  //    float  TMin;
+  //    float3 Direction;
+  //    float  TMax;
+  //};
+  IRBuilder<> Builder(CI);
+  Value *zeroIdx = hlslOP->GetU32Const(0);
+  // Field 0: Origin, split into three scalar operands.
+  Value *origin = Builder.CreateGEP(rayDesc, {zeroIdx, zeroIdx});
+  origin = Builder.CreateLoad(origin);
+  // `index` walks the DXIL operand slots reserved for the flattened RayDesc.
+  unsigned index = DXIL::OperandIndex::kTraceRayInlineRayDescOpIdx;
+  Args[index++] = Builder.CreateExtractElement(origin, (uint64_t)0);
+  Args[index++] = Builder.CreateExtractElement(origin, 1);
+  Args[index++] = Builder.CreateExtractElement(origin, 2);
+
+  // Field 1: TMin.
+  Value *tmin = Builder.CreateGEP(rayDesc, {zeroIdx, hlslOP->GetU32Const(1)});
+  tmin = Builder.CreateLoad(tmin);
+  Args[index++] = tmin;
+
+  // Field 2: Direction, split into three scalar operands.
+  Value *direction = Builder.CreateGEP(rayDesc, {zeroIdx, hlslOP->GetU32Const(2)});
+  direction = Builder.CreateLoad(direction);
+
+  Args[index++] = Builder.CreateExtractElement(direction, (uint64_t)0);
+  Args[index++] = Builder.CreateExtractElement(direction, 1);
+  Args[index++] = Builder.CreateExtractElement(direction, 2);
+
+  // Field 3: TMax.
+  Value *tmax = Builder.CreateGEP(rayDesc, {zeroIdx, hlslOP->GetU32Const(3)});
+  tmax = Builder.CreateLoad(tmax);
+  Args[index++] = tmax;
+
+  // Every slot of Args must have been written exactly once.
+  DXASSERT_NOMSG(index == DXIL::OperandIndex::kTraceRayInlineNumOp);
+
+  Function *F = hlslOP->GetOpFunc(opcode, Builder.getVoidTy());
+
+  return Builder.CreateCall(F, Args);
+}
+
+// Lowers the CommitProceduralPrimitiveHit method intrinsic. The DXIL call
+// receives the opcode constant, the value loaded from the object operand and
+// the intrinsic's second source operand; the op function is looked up with
+// the void type. Returns the created call.
+Value *TranslateCommitProceduralPrimitiveHit(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
+                         HLOperationLowerHelper &helper,
+                         HLObjectOperationLowerHelper *pObjHelper,
+                         bool &Translated) {
+  hlsl::OP &dxilOps = helper.hlslOP;
+  Value *hitT = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
+  Value *opcodeArg = dxilOps.GetU32Const(static_cast<unsigned>(opcode));
+  Value *queryHandle = TranslateThisPointerToi32Handle(CI, &dxilOps);
+
+  Value *callArgs[] = {opcodeArg, queryHandle, hitT};
+
+  IRBuilder<> Builder(CI);
+  Function *dxilFunc = dxilOps.GetOpFunc(opcode, Builder.getVoidTy());
+  return Builder.CreateCall(dxilFunc, callArgs);
+}
+
+// Lowers a RayQuery method that takes no operands beyond the object itself.
+// The DXIL call receives the opcode constant and the value loaded from the
+// object operand; the op function is looked up using CI's own type. Returns
+// the created call.
+Value *TranslateGenericRayQueryMethod(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
+                         HLOperationLowerHelper &helper,
+                         HLObjectOperationLowerHelper *pObjHelper,
+                         bool &Translated) {
+  hlsl::OP &dxilOps = helper.hlslOP;
+  Value *opcodeArg = dxilOps.GetU32Const(static_cast<unsigned>(opcode));
+  Value *queryHandle = TranslateThisPointerToi32Handle(CI, &dxilOps);
+
+  Value *callArgs[] = {opcodeArg, queryHandle};
+
+  IRBuilder<> Builder(CI);
+  Function *dxilFunc = dxilOps.GetOpFunc(opcode, CI->getType());
+  return Builder.CreateCall(dxilFunc, callArgs);
+}
+
+// Lowers a RayQuery method returning a 12-element vector. Two constant
+// vectors of row and column indices are passed to TrivialDxilOperation along
+// with the value loaded from the object operand; element i pairs
+// row = i / 4 with col = i % 4.
+// NOTE(review): presumably TrivialDxilOperation emits one scalar op per
+// element from these index vectors - confirm against its definition.
+Value *TranslateRayQueryMatrix3x4Operation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
+                         HLOperationLowerHelper &helper,
+                         HLObjectOperationLowerHelper *pObjHelper,
+                         bool &Translated) {
+  hlsl::OP *dxilOps = &helper.hlslOP;
+  VectorType *resultTy = cast<VectorType>(CI->getType());
+  Value *queryHandle = TranslateThisPointerToi32Handle(CI, dxilOps);
+
+  // Row indices are 32-bit, column indices are 8-bit.
+  uint32_t rowIdx[] = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2};
+  Constant *rows = ConstantDataVector::get(CI->getContext(), rowIdx);
+  uint8_t colIdx[] = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3};
+  Constant *cols = ConstantDataVector::get(CI->getContext(), colIdx);
+
+  return TrivialDxilOperation(opcode, {nullptr, queryHandle, rows, cols}, resultTy, CI, dxilOps);
+}
+
+// Lowers a RayQuery method returning a 12-element vector in transposed
+// element order. Two constant vectors of row and column indices are passed to
+// TrivialDxilOperation along with the value loaded from the object operand;
+// element i pairs row = i % 3 with col = i / 3.
+// NOTE(review): presumably TrivialDxilOperation emits one scalar op per
+// element from these index vectors - confirm against its definition.
+Value *TranslateRayQueryTransposedMatrix3x4Operation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
+                                                  HLOperationLowerHelper &helper,
+                                                  HLObjectOperationLowerHelper *pObjHelper,
+                                                  bool &Translated) {
+  hlsl::OP *dxilOps = &helper.hlslOP;
+  VectorType *resultTy = cast<VectorType>(CI->getType());
+  Value *queryHandle = TranslateThisPointerToi32Handle(CI, dxilOps);
+
+  // Row indices are 32-bit, column indices are 8-bit.
+  uint32_t rowIdx[] = { 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2 };
+  Constant *rows = ConstantDataVector::get(CI->getContext(), rowIdx);
+  uint8_t colIdx[] = { 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3 };
+  Constant *cols = ConstantDataVector::get(CI->getContext(), colIdx);
+
+  return TrivialDxilOperation(opcode, {nullptr, queryHandle, rows, cols}, resultTy, CI, dxilOps);
+}
+
+// Lowers a RayQuery method returning a 2-element vector. A constant 8-bit
+// index vector {0, 1} is passed to TrivialDxilOperation together with the
+// value loaded from the object operand.
+// NOTE(review): presumably each index selects one component of the result -
+// confirm against TrivialDxilOperation.
+Value *TranslateRayQueryFloat2Getter(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
+                         HLOperationLowerHelper &helper,
+                         HLObjectOperationLowerHelper *pObjHelper,
+                         bool &Translated) {
+  hlsl::OP *dxilOps = &helper.hlslOP;
+  VectorType *resultTy = cast<VectorType>(CI->getType());
+  Value *queryHandle = TranslateThisPointerToi32Handle(CI, dxilOps);
+
+  uint8_t componentIdx[] = {0, 1};
+  Constant *components = ConstantDataVector::get(CI->getContext(), componentIdx);
+
+  return TrivialDxilOperation(opcode, {nullptr, queryHandle, components}, resultTy, CI, dxilOps);
+}
+
+// Lowers a RayQuery method returning a 3-element vector. A constant 8-bit
+// index vector {0, 1, 2} is passed to TrivialDxilOperation together with the
+// value loaded from the object operand.
+// NOTE(review): presumably each index selects one component of the result -
+// confirm against TrivialDxilOperation.
+Value *TranslateRayQueryFloat3Getter(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
+                         HLOperationLowerHelper &helper,
+                         HLObjectOperationLowerHelper *pObjHelper,
+                         bool &Translated) {
+  hlsl::OP *dxilOps = &helper.hlslOP;
+  VectorType *resultTy = cast<VectorType>(CI->getType());
+  Value *queryHandle = TranslateThisPointerToi32Handle(CI, dxilOps);
+
+  uint8_t componentIdx[] = {0, 1, 2};
+  Constant *components = ConstantDataVector::get(CI->getContext(), componentIdx);
+
+  return TrivialDxilOperation(opcode, {nullptr, queryHandle, components}, resultTy, CI, dxilOps);
+}
+
 Value *TranslateNoArgVectorOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
                          HLOperationLowerHelper &helper,
                          HLObjectOperationLowerHelper *pObjHelper,
@@ -4789,7 +5080,7 @@ Value *StreamOutputLower(CallInst *CI, IntrinsicOp IOP, DXIL::OpCode opcode,
 }
 
 // This table has to match IntrinsicOp orders
-IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] = {
+IntrinsicLower gLowerTable[] = {
     {IntrinsicOp::IOP_AcceptHitAndEndSearch, TranslateNoArgNoReturnPreserveOutput, DXIL::OpCode::AcceptHitAndEndSearch},
     {IntrinsicOp::IOP_AddUint64,  TranslateAddUint64,  DXIL::OpCode::UAddc},
     {IntrinsicOp::IOP_AllMemoryBarrier, TrivialBarrier, DXIL::OpCode::Barrier},
@@ -4799,11 +5090,13 @@ IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] =
     {IntrinsicOp::IOP_D3DCOLORtoUBYTE4, TranslateD3DColorToUByte4, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::IOP_DeviceMemoryBarrier, TrivialBarrier, DXIL::OpCode::Barrier},
     {IntrinsicOp::IOP_DeviceMemoryBarrierWithGroupSync, TrivialBarrier, DXIL::OpCode::Barrier},
+    {IntrinsicOp::IOP_DispatchMesh, TrivialDispatchMesh, DXIL::OpCode::DispatchMesh },
     {IntrinsicOp::IOP_DispatchRaysDimensions, TranslateNoArgVectorOperation, DXIL::OpCode::DispatchRaysDimensions},
     {IntrinsicOp::IOP_DispatchRaysIndex, TranslateNoArgVectorOperation, DXIL::OpCode::DispatchRaysIndex},
     {IntrinsicOp::IOP_EvaluateAttributeAtSample, TranslateEvalSample, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::IOP_EvaluateAttributeCentroid, TranslateEvalCentroid, DXIL::OpCode::EvalCentroid},
     {IntrinsicOp::IOP_EvaluateAttributeSnapped, TranslateEvalSnapped, DXIL::OpCode::NumOpCodes},
+    {IntrinsicOp::IOP_GeometryIndex, TrivialNoArgWithRetOperation, DXIL::OpCode::GeometryIndex},
     {IntrinsicOp::IOP_GetAttributeAtVertex, TranslateGetAttributeAtVertex, DXIL::OpCode::AttributeAtVertex},
     {IntrinsicOp::IOP_GetRenderTargetSampleCount, TrivialNoArgOperation, DXIL::OpCode::RenderTargetGetSampleCount},
     {IntrinsicOp::IOP_GetRenderTargetSamplePosition, TranslateGetRTSamplePos, DXIL::OpCode::NumOpCodes},
@@ -4847,6 +5140,7 @@ IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] =
     {IntrinsicOp::IOP_RayTCurrent, TrivialNoArgWithRetOperation, DXIL::OpCode::RayTCurrent},
     {IntrinsicOp::IOP_RayTMin, TrivialNoArgWithRetOperation, DXIL::OpCode::RayTMin},
     {IntrinsicOp::IOP_ReportHit, TranslateReportIntersection, DXIL::OpCode::ReportHit},
+    {IntrinsicOp::IOP_SetMeshOutputCounts, TrivialSetMeshOutputCounts, DXIL::OpCode::SetMeshOutputCounts},
     {IntrinsicOp::IOP_TraceRay, TranslateTraceRay, DXIL::OpCode::TraceRay},
     {IntrinsicOp::IOP_WaveActiveAllEqual, TranslateWaveAllEqual, DXIL::OpCode::WaveActiveAllEqual},
     {IntrinsicOp::IOP_WaveActiveAllTrue, TranslateWaveA2B, DXIL::OpCode::WaveAllTrue},
@@ -5030,6 +5324,49 @@ IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] =
     {IntrinsicOp::MOP_DecrementCounter, GenerateUpdateCounter, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::MOP_IncrementCounter, GenerateUpdateCounter, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::MOP_Consume, EmptyLower, DXIL::OpCode::NumOpCodes},
+    {IntrinsicOp::MOP_WriteSamplerFeedback, TranslateWriteSamplerFeedback, DXIL::OpCode::WriteSamplerFeedback},
+    {IntrinsicOp::MOP_WriteSamplerFeedbackBias, TranslateWriteSamplerFeedback, DXIL::OpCode::WriteSamplerFeedbackBias},
+    {IntrinsicOp::MOP_WriteSamplerFeedbackGrad, TranslateWriteSamplerFeedback, DXIL::OpCode::WriteSamplerFeedbackGrad},
+    {IntrinsicOp::MOP_WriteSamplerFeedbackLevel, TranslateWriteSamplerFeedback, DXIL::OpCode::WriteSamplerFeedbackLevel},
+
+    {IntrinsicOp::MOP_Abort, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_Abort},
+    {IntrinsicOp::MOP_CandidateGeometryIndex, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CandidateGeometryIndex},
+    {IntrinsicOp::MOP_CandidateInstanceID, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CandidateInstanceID},
+    {IntrinsicOp::MOP_CandidateInstanceIndex, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CandidateInstanceIndex},
+    {IntrinsicOp::MOP_CandidateObjectRayDirection, TranslateRayQueryFloat3Getter, DXIL::OpCode::RayQuery_CandidateObjectRayDirection},
+    {IntrinsicOp::MOP_CandidateObjectRayOrigin, TranslateRayQueryFloat3Getter, DXIL::OpCode::RayQuery_CandidateObjectRayOrigin},
+    {IntrinsicOp::MOP_CandidateObjectToWorld3x4, TranslateRayQueryMatrix3x4Operation, DXIL::OpCode::RayQuery_CandidateObjectToWorld3x4},
+    {IntrinsicOp::MOP_CandidateObjectToWorld4x3, TranslateRayQueryTransposedMatrix3x4Operation, DXIL::OpCode::RayQuery_CandidateObjectToWorld3x4},
+    {IntrinsicOp::MOP_CandidatePrimitiveIndex, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CandidatePrimitiveIndex},
+    {IntrinsicOp::MOP_CandidateProceduralPrimitiveNonOpaque, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CandidateProceduralPrimitiveNonOpaque},
+    {IntrinsicOp::MOP_CandidateTriangleBarycentrics, TranslateRayQueryFloat2Getter, DXIL::OpCode::RayQuery_CandidateTriangleBarycentrics},
+    {IntrinsicOp::MOP_CandidateTriangleFrontFace, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CandidateTriangleFrontFace},
+    {IntrinsicOp::MOP_CandidateTriangleRayT, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CandidateTriangleRayT},
+    {IntrinsicOp::MOP_CandidateType, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CandidateType},
+    {IntrinsicOp::MOP_CandidateWorldToObject3x4, TranslateRayQueryMatrix3x4Operation, DXIL::OpCode::RayQuery_CandidateWorldToObject3x4},
+    {IntrinsicOp::MOP_CandidateWorldToObject4x3, TranslateRayQueryTransposedMatrix3x4Operation, DXIL::OpCode::RayQuery_CandidateWorldToObject3x4},
+    {IntrinsicOp::MOP_CommitNonOpaqueTriangleHit, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CommitNonOpaqueTriangleHit},
+    {IntrinsicOp::MOP_CommitProceduralPrimitiveHit, TranslateCommitProceduralPrimitiveHit, DXIL::OpCode::RayQuery_CommitProceduralPrimitiveHit},
+    {IntrinsicOp::MOP_CommittedGeometryIndex, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CommittedGeometryIndex},
+    {IntrinsicOp::MOP_CommittedInstanceID, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CommittedInstanceID},
+    {IntrinsicOp::MOP_CommittedInstanceIndex, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CommittedInstanceIndex},
+    {IntrinsicOp::MOP_CommittedObjectRayDirection, TranslateRayQueryFloat3Getter, DXIL::OpCode::RayQuery_CommittedObjectRayDirection},
+    {IntrinsicOp::MOP_CommittedObjectRayOrigin, TranslateRayQueryFloat3Getter, DXIL::OpCode::RayQuery_CommittedObjectRayOrigin},
+    {IntrinsicOp::MOP_CommittedObjectToWorld3x4, TranslateRayQueryMatrix3x4Operation, DXIL::OpCode::RayQuery_CommittedObjectToWorld3x4},
+    {IntrinsicOp::MOP_CommittedObjectToWorld4x3, TranslateRayQueryTransposedMatrix3x4Operation, DXIL::OpCode::RayQuery_CommittedObjectToWorld3x4},
+    {IntrinsicOp::MOP_CommittedPrimitiveIndex, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CommittedPrimitiveIndex},
+    {IntrinsicOp::MOP_CommittedRayT, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CommittedRayT},
+    {IntrinsicOp::MOP_CommittedStatus, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CommittedStatus},
+    {IntrinsicOp::MOP_CommittedTriangleBarycentrics, TranslateRayQueryFloat2Getter, DXIL::OpCode::RayQuery_CommittedTriangleBarycentrics},
+    {IntrinsicOp::MOP_CommittedTriangleFrontFace, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_CommittedTriangleFrontFace},
+    {IntrinsicOp::MOP_CommittedWorldToObject3x4, TranslateRayQueryMatrix3x4Operation, DXIL::OpCode::RayQuery_CommittedWorldToObject3x4},
+    {IntrinsicOp::MOP_CommittedWorldToObject4x3, TranslateRayQueryTransposedMatrix3x4Operation, DXIL::OpCode::RayQuery_CommittedWorldToObject3x4},
+    {IntrinsicOp::MOP_Proceed, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_Proceed},
+    {IntrinsicOp::MOP_RayFlags, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_RayFlags},
+    {IntrinsicOp::MOP_RayTMin, TranslateGenericRayQueryMethod, DXIL::OpCode::RayQuery_RayTMin},
+    {IntrinsicOp::MOP_TraceRayInline,  TranslateTraceRayInline,  DXIL::OpCode::RayQuery_TraceRayInline},
+    {IntrinsicOp::MOP_WorldRayDirection, TranslateRayQueryFloat3Getter, DXIL::OpCode::RayQuery_WorldRayDirection},
+    {IntrinsicOp::MOP_WorldRayOrigin, TranslateRayQueryFloat3Getter, DXIL::OpCode::RayQuery_WorldRayOrigin},
 
     // SPIRV change starts
 #ifdef ENABLE_SPIRV_CODEGEN
@@ -5060,6 +5397,8 @@ IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] =
     { IntrinsicOp::MOP_InterlockedUMin, TranslateMopAtomicBinaryOperation, DXIL::OpCode::NumOpCodes },
 };
 }
+static_assert(sizeof(gLowerTable) / sizeof(gLowerTable[0]) == static_cast<size_t>(IntrinsicOp::Num_Intrinsics),
+  "Intrinsic lowering table must be updated to account for new intrinsics.");
 
 static void TranslateBuiltinIntrinsic(CallInst *CI,
                                       HLOperationLowerHelper &helper,  HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
@@ -5771,6 +6110,11 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
       }
 
       CI->eraseFromParent();
+    } else if (group == HLOpcodeGroup::HLIntrinsic) {
+      // FIXME: This case is hit when using built-in structures in constant
+      //        buffers passed directly to an intrinsic, such as:
+      //        RayDesc from cbuffer passed to TraceRay.
+      DXASSERT(0, "not implemented yet");
     } else {
       DXASSERT(0, "not implemented yet");
     }
@@ -6260,7 +6604,7 @@ void TranslateStructBufMatSt(Type *matType, IRBuilder<> &Builder, Value *handle,
   for (unsigned i = 0; i < matSize; i++)
     elts[i] = Builder.CreateExtractElement(val, i);
 
-  for (unsigned i = 0; i < matSize; i += 4) {
+  for (unsigned i = 0; i < matSize; i += 4) { 
     uint8_t mask = 0;
     for (unsigned j = 0; j < 4 && (i+j) < matSize; j++) {
       if (elts[i+j] != undefElt)
@@ -7320,6 +7664,8 @@ void TranslateBuiltinOperations(
 
   Module *M = HLM.GetModule();
 
+  AllocateRayQueryObjects(M, helper);
+
   SmallVector<Function *, 4> NonUniformResourceIndexIntrinsics;
 
   // generate dxil operation

+ 208 - 62
lib/HLSL/HLSignatureLower.cpp

@@ -129,6 +129,18 @@ void replaceInputOutputWithIntrinsic(DXIL::SemanticKind semKind, Value *GV,
   case Semantic::Kind::ViewID:
     opcode = OP::OpCode::ViewID;
     break;
+  case Semantic::Kind::GroupThreadID:
+    opcode = OP::OpCode::ThreadIdInGroup;
+    break;
+  case Semantic::Kind::GroupID:
+    opcode = OP::OpCode::GroupId;
+    break;
+  case Semantic::Kind::DispatchThreadID:
+    opcode = OP::OpCode::ThreadId;
+    break;
+  case Semantic::Kind::GroupIndex:
+    opcode = OP::OpCode::FlattenedThreadIdInGroup;
+    break;
   default:
     DXASSERT(0, "invalid semantic");
     return;
@@ -138,19 +150,25 @@ void replaceInputOutputWithIntrinsic(DXIL::SemanticKind semKind, Value *GV,
   Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
 
   Value *newArg = nullptr;
-  if (semKind == Semantic::Kind::DomainLocation) {
+  if (semKind == Semantic::Kind::DomainLocation ||
+      semKind == Semantic::Kind::GroupThreadID ||
+      semKind == Semantic::Kind::GroupID ||
+      semKind == Semantic::Kind::DispatchThreadID) {
     unsigned vecSize = 1;
     if (Ty->isVectorTy())
       vecSize = Ty->getVectorNumElements();
 
-    newArg = Builder.CreateCall(dxilFunc, {OpArg, hlslOP->GetU8Const(0)});
+    newArg = Builder.CreateCall(dxilFunc, { OpArg,
+      semKind == Semantic::Kind::DomainLocation ? hlslOP->GetU8Const(0) : hlslOP->GetU32Const(0) });
     if (vecSize > 1) {
       Value *result = UndefValue::get(Ty);
       result = Builder.CreateInsertElement(result, newArg, (uint64_t)0);
 
       for (unsigned i = 1; i < vecSize; i++) {
         Value *newElt =
-            Builder.CreateCall(dxilFunc, {OpArg, hlslOP->GetU8Const(i)});
+            Builder.CreateCall(dxilFunc, { OpArg,
+              semKind == Semantic::Kind::DomainLocation ? hlslOP->GetU8Const(i)
+                                                        : hlslOP->GetU32Const(i) });
         result = Builder.CreateInsertElement(result, newElt, i);
       }
       newArg = result;
@@ -223,7 +241,15 @@ void HLSignatureLower::ProcessArgument(Function *func,
       paramAnnotation.GetInterpolationMode().GetKind();
 
   // Set undefined interpMode.
-  if (!sigPoint->NeedsInterpMode())
+  if (sigPoint->GetKind() == DXIL::SigPointKind::MSPOut) {
+    if (interpMode != InterpolationMode::Kind::Undefined &&
+        interpMode != InterpolationMode::Kind::Constant) {
+      Entry->getContext().emitError(
+        "Mesh shader's primitive outputs' interpolation mode must be constant or undefined.");
+    }
+    interpMode = InterpolationMode::Kind::Constant;
+  }
+  else if (!sigPoint->NeedsInterpMode())
     interpMode = InterpolationMode::Kind::Undefined;
   else if (interpMode == InterpolationMode::Kind::Undefined) {
     // Type-based default: linear for floats, constant for others.
@@ -266,7 +292,7 @@ void HLSignatureLower::ProcessArgument(Function *func,
             ? m_InputSemanticsUsed
             : (sigPoint->IsOutput()
                    ? m_OutputSemanticsUsed[streamIdx]
-                   : (sigPoint->IsPatchConstant() ? m_PatchConstantSemanticsUsed
+                   : (sigPoint->IsPatchConstOrPrim() ? m_PatchConstantSemanticsUsed
                                                   : m_OtherSemanticsUsed));
     if (SemanticUseMap.count((unsigned)pSemantic->GetKind()) > 0) {
       auto &SemanticIndexSet = SemanticUseMap[(unsigned)pSemantic->GetKind()];
@@ -346,8 +372,8 @@ void HLSignatureLower::ProcessArgument(Function *func,
   case DXIL::SignatureKind::Output:
     pSig = &EntrySig.OutputSignature;
     break;
-  case DXIL::SignatureKind::PatchConstant:
-    pSig = &EntrySig.PatchConstantSignature;
+  case DXIL::SignatureKind::PatchConstOrPrim:
+    pSig = &EntrySig.PatchConstOrPrimSignature;
     break;
   default:
     DXASSERT(false, "Expected real signature kind at this point");
@@ -359,7 +385,7 @@ void HLSignatureLower::ProcessArgument(Function *func,
   {
     // Add signature element to appropriate maps
     if (isPatchConstantFunction &&
-        sigKind != DXIL::SignatureKind::PatchConstant) {
+        sigKind != DXIL::SignatureKind::PatchConstOrPrim) {
       pSE = FindArgInSignature(arg, paramAnnotation.GetSemanticString(),
                                interpMode, sigPoint->GetKind(), *pSig);
       if (!pSE) {
@@ -415,6 +441,14 @@ void HLSignatureLower::CreateDxilSignatures() {
     if (HLModule::IsStreamOutputPtrType(Ty))
       continue;
 
+    // Skip OutIndices and InPayload
+    DxilParameterAnnotation &paramAnnotation =
+      EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
+    hlsl::DxilParamInputQual qual = paramAnnotation.GetParamInputQual();
+    if (qual == hlsl::DxilParamInputQual::OutIndices ||
+        qual == hlsl::DxilParamInputQual::InPayload)
+      continue;
+
     ProcessArgument(Entry, EntryAnnotation, arg, props, pSM,
                     isPatchConstantFunctionFalse, bForOutFasle, bHasClipPlane);
   }
@@ -463,16 +497,19 @@ void HLSignatureLower::AllocateDxilInputOutputs() {
         "Failed to allocate all input signature elements in available space.");
   }
 
-  hlsl::PackDxilSignature(EntrySig.OutputSignature, packing);
-  if (!EntrySig.OutputSignature.IsFullyAllocated()) {
-    HLM.GetCtx().emitError(
-        "Failed to allocate all output signature elements in available space.");
+  if (props.shaderKind != DXIL::ShaderKind::Amplification) {
+    hlsl::PackDxilSignature(EntrySig.OutputSignature, packing);
+    if (!EntrySig.OutputSignature.IsFullyAllocated()) {
+      HLM.GetCtx().emitError(
+          "Failed to allocate all output signature elements in available space.");
+    }
   }
 
   if (props.shaderKind == DXIL::ShaderKind::Hull ||
-      props.shaderKind == DXIL::ShaderKind::Domain) {
-    hlsl::PackDxilSignature(EntrySig.PatchConstantSignature, packing);
-    if (!EntrySig.PatchConstantSignature.IsFullyAllocated()) {
+      props.shaderKind == DXIL::ShaderKind::Domain ||
+      props.shaderKind == DXIL::ShaderKind::Mesh) {
+    hlsl::PackDxilSignature(EntrySig.PatchConstOrPrimSignature, packing);
+    if (!EntrySig.PatchConstOrPrimSignature.IsFullyAllocated()) {
       HLM.GetCtx().emitError("Failed to allocate all patch constant signature "
                              "elements in available space.");
     }
@@ -494,7 +531,7 @@ void GenerateStOutput(Function *stOutput, MutableArrayRef<Value *> args,
 
 void replaceStWithStOutput(Function *stOutput, StoreInst *stInst,
                            Constant *OpArg, Constant *outputID, Value *idx,
-                           unsigned cols, bool bI1Cast) {
+                           unsigned cols, Value *vertexOrPrimID, bool bI1Cast) {
   IRBuilder<> Builder(stInst);
   Value *val = stInst->getValueOperand();
 
@@ -503,7 +540,9 @@ void replaceStWithStOutput(Function *stOutput, StoreInst *stInst,
     for (unsigned col = 0; col < cols; col++) {
       Value *subVal = Builder.CreateExtractElement(val, col);
       Value *colIdx = Builder.getInt8(col);
-      Value *args[] = {OpArg, outputID, idx, colIdx, subVal};
+      SmallVector<Value *, 4> args = {OpArg, outputID, idx, colIdx, subVal};
+      if (vertexOrPrimID)
+        args.emplace_back(vertexOrPrimID);
       GenerateStOutput(stOutput, args, Builder, bI1Cast);
     }
     // remove stInst
@@ -512,7 +551,9 @@ void replaceStWithStOutput(Function *stOutput, StoreInst *stInst,
     // TODO: support case cols not 1
     DXASSERT(cols == 1, "only support scalar here");
     Value *colIdx = Builder.getInt8(0);
-    Value *args[] = {OpArg, outputID, idx, colIdx, val};
+    SmallVector<Value *, 4> args = {OpArg, outputID, idx, colIdx, val};
+    if (vertexOrPrimID)
+      args.emplace_back(vertexOrPrimID);
     GenerateStOutput(stOutput, args, Builder, bI1Cast);
     // remove stInst
     stInst->eraseFromParent();
@@ -706,22 +747,22 @@ void replaceDirectInputParameter(Value *param, Function *loadInput,
 struct InputOutputAccessInfo {
   // For input output which has only 1 row, idx is 0.
   Value *idx;
-  // VertexID for HS/DS/GS input.
-  Value *vertexID;
+  // VertexID for HS/DS/GS input, MS vertex output. PrimitiveID for MS primitive output
+  Value *vertexOrPrimID;
   // Vector index.
   Value *vectorIdx;
   // Load/Store/LoadMat/StoreMat on input/output.
   Instruction *user;
   InputOutputAccessInfo(Value *index, Instruction *I)
-      : idx(index), vertexID(nullptr), vectorIdx(nullptr), user(I) {}
+      : idx(index), vertexOrPrimID(nullptr), vectorIdx(nullptr), user(I) {}
   InputOutputAccessInfo(Value *index, Instruction *I, Value *ID, Value *vecIdx)
-      : idx(index), vertexID(ID), vectorIdx(vecIdx), user(I) {}
+      : idx(index), vertexOrPrimID(ID), vectorIdx(vecIdx), user(I) {}
 };
 
 void collectInputOutputAccessInfo(
     Value *GV, Constant *constZero,
-    std::vector<InputOutputAccessInfo> &accessInfoList, bool hasVertexID,
-    bool bInput, bool bRowMajor) {
+    std::vector<InputOutputAccessInfo> &accessInfoList, bool hasVertexOrPrimID,
+    bool bInput, bool bRowMajor, bool isMS) {
   // merge GEP use for input output.
   HLModule::MergeGepUse(GV);
   for (auto User = GV->user_begin(); User != GV->user_end();) {
@@ -743,15 +784,15 @@ void collectInputOutputAccessInfo(
       DXASSERT_LOCALVAR(idx, idx->get() == constZero,
                         "only support 0 offset for input pointer");
 
-      Value *vertexID = nullptr;
+      Value *vertexOrPrimID = nullptr;
       Value *vectorIdx = nullptr;
       gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);
 
       // Skip first pointer idx which must be 0.
       GEPIt++;
-      if (hasVertexID) {
-        // Save vertexID.
-        vertexID = GEPIt.getOperand();
+      if (hasVertexOrPrimID) {
+        // Save vertexOrPrimID.
+        vertexOrPrimID = GEPIt.getOperand();
         GEPIt++;
       }
       // Start from first index.
@@ -806,12 +847,12 @@ void collectInputOutputAccessInfo(
         auto GepUserIt = GepUser++;
         if (LoadInst *ldInst = dyn_cast<LoadInst>(*GepUserIt)) {
           if (bInput) {
-            InputOutputAccessInfo info = {idxVal, ldInst, vertexID, vectorIdx};
+            InputOutputAccessInfo info = {idxVal, ldInst, vertexOrPrimID, vectorIdx};
             accessInfoList.push_back(info);
           }
         } else if (StoreInst *stInst = dyn_cast<StoreInst>(*GepUserIt)) {
           if (!bInput) {
-            InputOutputAccessInfo info = {idxVal, stInst, vertexID, vectorIdx};
+            InputOutputAccessInfo info = {idxVal, stInst, vertexOrPrimID, vectorIdx};
             accessInfoList.push_back(info);
           }
         } else if (CallInst *CI = dyn_cast<CallInst>(*GepUserIt)) {
@@ -823,7 +864,7 @@ void collectInputOutputAccessInfo(
                opcode == HLMatLoadStoreOpcode::RowMatLoad)
                   ? bInput
                   : !bInput) {
-            InputOutputAccessInfo info = {idxVal, CI, vertexID, vectorIdx};
+            InputOutputAccessInfo info = {idxVal, CI, vertexOrPrimID, vectorIdx};
             accessInfoList.push_back(info);
           }
         } else {
@@ -842,17 +883,17 @@ void collectInputOutputAccessInfo(
 void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertexIdx,
     Function *ldStFunc, Constant *OpArg, Constant *ID, unsigned cols, bool bI1Cast,
     Constant *columnConsts[],
-    bool bNeedVertexID, bool isArrayTy, bool bInput, bool bIsInout) {
+    bool bNeedVertexOrPrimID, bool isArrayTy, bool bInput, bool bIsInout) {
   Value *idxVal = info.idx;
-  Value *vertexID = undefVertexIdx;
-  if (bNeedVertexID && isArrayTy) {
-    vertexID = info.vertexID;
+  Value *vertexOrPrimID = undefVertexIdx;
+  if (bNeedVertexOrPrimID && isArrayTy) {
+    vertexOrPrimID = info.vertexOrPrimID;
   }
 
   if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
     SmallVector<Value *, 4> args = {OpArg, ID, idxVal, info.vectorIdx};
-    if (vertexID)
-      args.emplace_back(vertexID);
+    if (vertexOrPrimID)
+      args.emplace_back(vertexOrPrimID);
 
     replaceLdWithLdInput(ldStFunc, ldInst, cols, args, bI1Cast);
   } else if (StoreInst *stInst = dyn_cast<StoreInst>(info.user)) {
@@ -861,7 +902,7 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
     } else {
       if (!info.vectorIdx) {
         replaceStWithStOutput(ldStFunc, stInst, OpArg, ID, idxVal, cols,
-                              bI1Cast);
+                              vertexOrPrimID, bI1Cast);
       } else {
         Value *V = stInst->getValueOperand();
         Type *Ty = V->getType();
@@ -873,7 +914,9 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
           if (ColIdx->getType()->getBitWidth() != 8) {
             ColIdx = Builder.getInt8(ColIdx->getValue().getLimitedValue());
           }
-          Value *args[] = {OpArg, ID, idxVal, ColIdx, V};
+          SmallVector<Value *, 6> args = {OpArg, ID, idxVal, ColIdx, V};
+          if (vertexOrPrimID)
+            args.emplace_back(vertexOrPrimID);
           GenerateStOutput(ldStFunc, args, Builder, bI1Cast);
         } else {
           BasicBlock *BB = stInst->getParent();
@@ -894,7 +937,9 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
 
             ConstantInt *CaseIdx = SwitchBuilder.getInt8(i);
 
-            Value *args[] = {OpArg, ID, idxVal, CaseIdx, V};
+            SmallVector<Value *, 6> args = {OpArg, ID, idxVal, CaseIdx, V};
+            if (vertexOrPrimID)
+              args.emplace_back(vertexOrPrimID);
             GenerateStOutput(ldStFunc, args, CaseBuilder, bI1Cast);
 
             CaseBuilder.CreateBr(EndBB);
@@ -927,8 +972,8 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
           Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
           for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
             SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[r] };
-            if (vertexID)
-              args.emplace_back(vertexID);
+            if (vertexOrPrimID)
+              args.emplace_back(vertexOrPrimID);
 
             Value *input = LocalBuilder.CreateCall(ldStFunc, args);
             unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
@@ -941,8 +986,8 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
           Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
           for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
             SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[c] };
-            if (vertexID)
-              args.emplace_back(vertexID);
+            if (vertexOrPrimID)
+              args.emplace_back(vertexOrPrimID);
 
             Value *input = LocalBuilder.CreateCall(ldStFunc, args);
             unsigned matIdx = MatTy.getRowMajorIndex(r, c);
@@ -1003,20 +1048,39 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
 } // namespace
 
 void HLSignatureLower::GenerateDxilInputs() {
-  GenerateDxilInputsOutputs(/*bInput*/ true);
+  GenerateDxilInputsOutputs(DXIL::SignatureKind::Input);
 }
 
 void HLSignatureLower::GenerateDxilOutputs() {
-  GenerateDxilInputsOutputs(/*bInput*/ false);
+  GenerateDxilInputsOutputs(DXIL::SignatureKind::Output);
 }
 
-void HLSignatureLower::GenerateDxilInputsOutputs(bool bInput) {
+void HLSignatureLower::GenerateDxilPrimOutputs() {
+  GenerateDxilInputsOutputs(DXIL::SignatureKind::PatchConstOrPrim);
+}
+
+void HLSignatureLower::GenerateDxilInputsOutputs(DXIL::SignatureKind SK) {
   OP *hlslOP = HLM.GetOP();
   DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
   Module &M = *(HLM.GetModule());
 
-  OP::OpCode opcode = bInput ? OP::OpCode::LoadInput : OP::OpCode::StoreOutput;
-  bool bNeedVertexID = bInput && (props.IsGS() || props.IsDS() || props.IsHS());
+  OP::OpCode opcode;
+  switch (SK) {
+  case DXIL::SignatureKind::Input:
+    opcode = OP::OpCode::LoadInput;
+    break;
+  case DXIL::SignatureKind::Output:
+    opcode = props.IsMS() ? OP::OpCode::StoreVertexOutput : OP::OpCode::StoreOutput;
+    break;
+  case DXIL::SignatureKind::PatchConstOrPrim:
+    opcode = OP::OpCode::StorePrimitiveOutput;
+    break;
+  default:
+    DXASSERT_NOMSG(0);
+  }
+  bool bInput = SK == DXIL::SignatureKind::Input;
+  bool bNeedVertexOrPrimID = bInput && (props.IsGS() || props.IsDS() || props.IsHS());
+  bNeedVertexOrPrimID |= !bInput && props.IsMS();
 
   Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
 
@@ -1030,10 +1094,12 @@ void HLSignatureLower::GenerateDxilInputsOutputs(bool bInput) {
 
   Constant *constZero = hlslOP->GetU32Const(0);
 
-  Value *undefVertexIdx = UndefValue::get(Type::getInt32Ty(HLM.GetCtx()));
+  Value *undefVertexIdx = props.IsMS() || !bInput ? nullptr : UndefValue::get(Type::getInt32Ty(HLM.GetCtx()));
 
   DxilSignature &Sig =
-      bInput ? EntrySig.InputSignature : EntrySig.OutputSignature;
+      bInput ? EntrySig.InputSignature :
+      SK == DXIL::SignatureKind::Output ? EntrySig.OutputSignature :
+      EntrySig.PatchConstOrPrimSignature;
 
   DxilTypeSystem &typeSys = HLM.GetTypeSystem();
   DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(Entry);
@@ -1083,9 +1149,9 @@ void HLSignatureLower::GenerateDxilInputsOutputs(bool bInput) {
 
     if (!GV->getType()->isPointerTy()) {
       DXASSERT(bInput, "direct parameter must be input");
-      Value *vertexID = undefVertexIdx;
+      Value *vertexOrPrimID = undefVertexIdx;
       Value *args[] = {OpArg, ID, /*rowIdx*/ constZero, /*colIdx*/ nullptr,
-                       vertexID};
+                       vertexOrPrimID};
       replaceDirectInputParameter(GV, dxilFunc, cols, args, bI1Cast, hlslOP,
                                   EntryBuilder);
       continue;
@@ -1106,11 +1172,11 @@ void HLSignatureLower::GenerateDxilInputsOutputs(bool bInput) {
     }
     std::vector<InputOutputAccessInfo> accessInfoList;
     collectInputOutputAccessInfo(GV, constZero, accessInfoList,
-                                 bNeedVertexID && bIsArrayTy, bInput, bRowMajor);
+                                 bNeedVertexOrPrimID && bIsArrayTy, bInput, bRowMajor, props.IsMS());
 
     for (InputOutputAccessInfo &info : accessInfoList) {
       GenerateInputOutputUserCall(info, undefVertexIdx, dxilFunc, OpArg, ID,
-                                  cols, bI1Cast, columnConsts, bNeedVertexID,
+                                  cols, bI1Cast, columnConsts, bNeedVertexOrPrimID,
                                   bIsArrayTy, bInput, bIsInout);
     }
   }
@@ -1208,14 +1274,14 @@ void HLSignatureLower::GenerateDxilPatchConstantLdSt() {
   DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
   Module &M = *(HLM.GetModule());
   Constant *constZero = hlslOP->GetU32Const(0);
-  DxilSignature &Sig = EntrySig.PatchConstantSignature;
+  DxilSignature &Sig = EntrySig.PatchConstOrPrimSignature;
   DxilTypeSystem &typeSys = HLM.GetTypeSystem();
   DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(Entry);
   auto InsertPt = Entry->getEntryBlock().getFirstInsertionPt();
   const bool bIsHs = props.IsHS();
   const bool bIsInput = !bIsHs;
   const bool bIsInout = false;
-  const bool bNeedVertexID = false;
+  const bool bNeedVertexOrPrimID = false;
   if (bIsHs) {
     DxilFunctionProps &EntryQual = HLM.GetDxilFunctionProps(Entry);
     Function *patchConstantFunc = EntryQual.ShaderProps.HS.patchConstantFunc;
@@ -1284,8 +1350,8 @@ void HLSignatureLower::GenerateDxilPatchConstantLdSt() {
       }
     }
     std::vector<InputOutputAccessInfo> accessInfoList;
-    collectInputOutputAccessInfo(GV, constZero, accessInfoList, bNeedVertexID,
-                                 bIsInput, bRowMajor);
+    collectInputOutputAccessInfo(GV, constZero, accessInfoList, bNeedVertexOrPrimID,
+                                 bIsInput, bRowMajor, false);
 
     bool bIsArrayTy = GV->getType()->getPointerElementType()->isArrayTy();
     bool isPrecise = m_preciseSigSet.count(SE);
@@ -1294,7 +1360,7 @@ void HLSignatureLower::GenerateDxilPatchConstantLdSt() {
 
     for (InputOutputAccessInfo &info : accessInfoList) {
       GenerateInputOutputUserCall(info, undefVertexIdx, dxilFunc, OpArg, ID,
-                                  cols, bI1Cast, columnConsts, bNeedVertexID,
+                                  cols, bI1Cast, columnConsts, bNeedVertexOrPrimID,
                                   bIsArrayTy, bIsInput, bIsInout);
     }
   }
@@ -1350,12 +1416,12 @@ void HLSignatureLower::GenerateDxilPatchConstantFunctionInputs() {
       }
       std::vector<InputOutputAccessInfo> accessInfoList;
       collectInputOutputAccessInfo(&arg, constZero, accessInfoList,
-                                   /*hasVertexID*/ true, true, bRowMajor);
+                                   /*hasVertexOrPrimID*/ true, true, bRowMajor, false);
       for (InputOutputAccessInfo &info : accessInfoList) {
         if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
           Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
           Value *args[] = {OpArg, inputID, info.idx, info.vectorIdx,
-                           info.vertexID};
+                           info.vertexOrPrimID};
           replaceLdWithLdInput(dxilLdFunc, ldInst, cols, args, bI1Cast);
         } else {
           DXASSERT(0, "input should only be ld");
@@ -1531,10 +1597,87 @@ void HLSignatureLower::GenerateStreamOutputOperations() {
     }
   }
 }
+// Generate DXIL EmitIndices operation: every store made through a GEP on indicesOutput is replaced by an EmitIndices call, then the store and the GEP are erased.
+void HLSignatureLower::GenerateEmitIndicesOperation(Value *indicesOutput) {
+  OP * hlslOP = HLM.GetOP();
+  Function *DxilFunc = hlslOP->GetOpFunc(OP::OpCode::EmitIndices, Type::getVoidTy(indicesOutput->getContext()));
+  Constant *opArg = hlslOP->GetU32Const((unsigned)OP::OpCode::EmitIndices);
+
+  for (auto U = indicesOutput->user_begin(); U != indicesOutput->user_end();) {
+    Value *user = *(U++);  // advance before the GEP is erased at the end of this iteration
+    GetElementPtrInst *GEP = cast<GetElementPtrInst>(user);  // every user of the argument is expected to be a GEP
+    auto idx = GEP->idx_begin();
+    DXASSERT_LOCALVAR(idx, idx->get() == hlslOP->GetU32Const(0),
+                      "only support 0 offset for input pointer");
+    gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);
+
+    // Skip first pointer idx which must be 0.
+    GEPIt++;
+    Value *primIdx = GEPIt.getOperand();  // second GEP index; passed to EmitIndices right after the opcode
+    DXASSERT(++GEPIt == E, "invalid GEP here");  // NOTE(review): the increment lives inside the assert; if DXASSERT compiles out it is skipped (GEPIt is not read afterwards) - confirm
+
+    auto GepUser = GEP->user_begin();
+    auto GepUserE = GEP->user_end();
+    for (; GepUser != GepUserE;) {
+      auto GepUserIt = GepUser++;  // advance first; the store is erased below
+      StoreInst *stInst = cast<StoreInst>(*GepUserIt);  // every GEP user is expected to be a store
+      Value *stVal = stInst->getValueOperand();
+      VectorType *VT = cast<VectorType>(stVal->getType());  // the stored value is expected to be a vector
+      unsigned eleCount = VT->getNumElements();
+      IRBuilder<> Builder(stInst);  // new instructions are inserted before the store
+      Value *subVal0 = Builder.CreateExtractElement(stVal, hlslOP->GetU32Const(0));
+      Value *subVal1 = Builder.CreateExtractElement(stVal, hlslOP->GetU32Const(1));
+      Value *subVal2 = eleCount == 3 ?  // third element only for 3-element vectors, otherwise constant 0
+        Builder.CreateExtractElement(stVal, hlslOP->GetU32Const(2)) : hlslOP->GetU32Const(0);
+      Value *args[] = { opArg, primIdx, subVal0, subVal1, subVal2 };
+      Builder.CreateCall(DxilFunc, args);
+      stInst->eraseFromParent();
+    }
+    GEP->eraseFromParent();
+  }
+}
+// Generate DXIL EmitIndices operations: runs GenerateEmitIndicesOperation on every Entry argument whose parameter annotation is OutIndices.
+void HLSignatureLower::GenerateEmitIndicesOperations() {
+  DxilFunctionAnnotation *EntryAnnotation = HLM.GetFunctionAnnotation(Entry);
+  DXASSERT(EntryAnnotation, "must find annotation for entry function");
+
+  for (Argument &arg : Entry->getArgumentList()) {
+    DxilParameterAnnotation &paramAnnotation =
+      EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
+    DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
+    if (inputQual == DxilParamInputQual::OutIndices) {  // all other arguments are left untouched here
+      GenerateEmitIndicesOperation(&arg);
+    }
+  }
+}
+// Generate DXIL GetMeshPayload operation: each Entry argument annotated InPayload has all its uses replaced by the result of a GetMeshPayload call.
+void HLSignatureLower::GenerateGetMeshPayloadOperation() {
+  DxilFunctionAnnotation *EntryAnnotation = HLM.GetFunctionAnnotation(Entry);
+  DXASSERT(EntryAnnotation, "must find annotation for entry function");
+
+  for (Argument &arg : Entry->getArgumentList()) {
+    DxilParameterAnnotation &paramAnnotation =
+      EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
+    DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
+    if (inputQual == DxilParamInputQual::InPayload) {
+      OP * hlslOP = HLM.GetOP();
+      Function *DxilFunc = hlslOP->GetOpFunc(OP::OpCode::GetMeshPayload, arg.getType());  // op function is requested for the argument's own type
+      Constant *opArg = hlslOP->GetU32Const((unsigned)OP::OpCode::GetMeshPayload);
+      IRBuilder<> Builder(arg.getParent()->getEntryBlock().getFirstInsertionPt());  // insert at the first insertion point of the function's entry block
+      Value *args[] = { opArg };  // the call takes only the opcode constant
+      Value *payload = Builder.CreateCall(DxilFunc, args);
+      arg.replaceAllUsesWith(payload);  // the argument itself is not removed here
+    }
+  }
+}
 // Lower signatures.
 void HLSignatureLower::Run() {
   DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
   if (props.IsGraphics()) {
+    if (props.IsMS()) {
+      GenerateEmitIndicesOperations();
+      GenerateGetMeshPayloadOperation();
+    }
     CreateDxilSignatures();
 
     // Allocate input output.
@@ -1542,6 +1685,9 @@ void HLSignatureLower::Run() {
 
     GenerateDxilInputs();
     GenerateDxilOutputs();
+    if (props.IsMS()) {
+      GenerateDxilPrimOutputs();
+    }
   } else if (props.IsCS()) {
     GenerateDxilCSInputs();
   }

+ 9 - 1
lib/HLSL/HLSignatureLower.h

@@ -12,6 +12,7 @@
 #pragma once
 #include <unordered_set>
 #include <unordered_map>
+#include "dxc/DXIL/DxilConstants.h"
 
 namespace llvm {
 class Value;
@@ -49,7 +50,8 @@ private:
   // Generate DXIL input load, output store
   void GenerateDxilInputs();
   void GenerateDxilOutputs();
-  void GenerateDxilInputsOutputs(bool bInput);
+  void GenerateDxilPrimOutputs();
+  void GenerateDxilInputsOutputs(DXIL::SignatureKind SK);
   void GenerateDxilCSInputs();
   void GenerateDxilPatchConstantLdSt();
   void GenerateDxilPatchConstantFunctionInputs();
@@ -59,6 +61,12 @@ private:
   void GenerateStreamOutputOperation(llvm::Value *streamVal, unsigned streamID);
   // Generate DXIL stream output operations.
   void GenerateStreamOutputOperations();
+  // Generate DXIL EmitIndices operation.
+  void GenerateEmitIndicesOperation(llvm::Value *indicesOutput);
+  // Generate DXIL EmitIndices operations.
+  void GenerateEmitIndicesOperations();
+  // Generate DXIL GetMeshPayload operation.
+  void GenerateGetMeshPayloadOperation();
 
 private:
   llvm::Function *Entry;

+ 25 - 8
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -1053,7 +1053,8 @@ void SROA_HLSL::isSafeForScalarRepl(Instruction *I, uint64_t Offset,
         IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(CI));
         if (IntrinsicOp::IOP_TraceRay == opcode ||
             IntrinsicOp::IOP_ReportHit == opcode ||
-            IntrinsicOp::IOP_CallShader == opcode) {
+            IntrinsicOp::IOP_CallShader == opcode ||
+            IntrinsicOp::IOP_DispatchMesh == opcode) {
           return MarkUnsafe(Info, User);
         }
       }
@@ -1586,9 +1587,9 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
       SimpleCopy(Dest, Src, idxList, Builder);
       return;
     }
+    // Built-in structs have no type annotation
     DxilStructAnnotation *STA = typeSys.GetStructAnnotation(ST);
-    DXASSERT(STA, "require annotation here");
-    if (STA->IsEmptyStruct())
+    if (STA && STA->IsEmptyStruct())
       return;
     for (uint32_t i = 0; i < ST->getNumElements(); i++) {
       llvm::Type *ET = ST->getElementType(i);
@@ -1598,8 +1599,8 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
       if (bEltMemCpy && IsMemCpyTy(ET, typeSys)) {
         EltMemCpy(ET, Dest, Src, idxList, Builder, DL);
       } else {
-        DxilFieldAnnotation &EltAnnotation = STA->GetFieldAnnotation(i);
-        SplitCpy(ET, Dest, Src, idxList, Builder, DL, typeSys, &EltAnnotation,
+        DxilFieldAnnotation *EltAnnotation = STA ? &STA->GetFieldAnnotation(i) : nullptr;
+        SplitCpy(ET, Dest, Src, idxList, Builder, DL, typeSys, EltAnnotation,
                  bEltMemCpy);
       }
 
@@ -2412,6 +2413,12 @@ void SROA_Helper::RewriteMemIntrin(MemIntrinsic *MI, Value *OldV) {
 }
 
 void SROA_Helper::RewriteBitCast(BitCastInst *BCI) {
+  // Unused bitcast may be leftover from temporary memcpy
+  if (BCI->use_empty()) {
+    BCI->eraseFromParent();
+    return;
+  }
+
   Type *DstTy = BCI->getType();
   Value *Val = BCI->getOperand(0);
   Type *SrcTy = Val->getType();
@@ -2565,6 +2572,13 @@ void SROA_Helper::RewriteCall(CallInst *CI) {
         RewriteCallArg(CI, HLOperandIndex::kBinaryOpSrc1Idx,
                        /*bIn*/ true, /*bOut*/ true);
       } break;
+      case IntrinsicOp::MOP_TraceRayInline: {
+        if (OldVal ==
+            CI->getArgOperand(HLOperandIndex::kTraceRayInlineRayDescOpIdx)) {
+          RewriteCallArg(CI, HLOperandIndex::kTraceRayInlineRayDescOpIdx,
+                         /*bIn*/ true, /*bOut*/ false);
+        }
+      } break;
       default:
         DXASSERT(0, "cannot flatten hlsl intrinsic.");
       }
@@ -2707,8 +2721,8 @@ bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
   IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
 
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    // Skip HLSL object types.
-    if (dxilutil::IsHLSLObjectType(ST)) {
+    // Skip HLSL object types and RayQuery.
+    if (dxilutil::IsHLSLObjectType(ST) || dxilutil::IsHLSLRayQueryType(ST)) {
       return false;
     }
 
@@ -4826,9 +4840,12 @@ void SROA_Parameter_HLSL::flattenArgument(
     std::vector<Value *> Elts;
 
     // Not flat vector for entry function currently.
-    bool SROAed = SROA_Helper::DoScalarReplacement(
+    bool SROAed = false;
+    if (inputQual != DxilParamInputQual::InPayload) {
+      SROAed = SROA_Helper::DoScalarReplacement(
         V, Elts, Builder, /*bFlatVector*/ false, annotation.IsPrecise(),
         dxilTypeSys, DL, DeadInsts);
+    }
 
     if (SROAed) {
       Type *Ty = V->getType()->getPointerElementType();

+ 7 - 0
tools/clang/include/clang/AST/HlslTypes.h

@@ -305,6 +305,8 @@ void AddRecordTypeWithHandle(
 void AddRayFlags(clang::ASTContext& context);
 void AddHitKinds(clang::ASTContext& context);
 void AddStateObjectFlags(clang::ASTContext& context);
+void AddCommittedStatus(clang::ASTContext& context);
+void AddCandidateType(clang::ASTContext& context);
 
 /// <summary>Adds the implementation for std::is_equal.</summary>
 void AddStdIsEqualImplementation(clang::ASTContext& context, clang::Sema& sema);
@@ -326,6 +328,11 @@ void AddTemplateTypeWithHandle(
             uint8_t templateArgCount,
   _In_opt_  clang::TypeSourceInfo* defaultTypeArgValue);
 
+void AddRayQueryTemplate(
+           clang::ASTContext& context,
+  _Outptr_ clang::ClassTemplateDecl** typeDecl,
+  _Outptr_ clang::CXXRecordDecl** recordDecl);
+
 /// <summary>Create a function template declaration for the specified method.</summary>
 /// <param name="context">AST context in which to work.</param>
 /// <param name="recordDecl">Class in which the function template is declared.</param>

+ 24 - 0
tools/clang/include/clang/Basic/Attr.td

@@ -871,6 +871,30 @@ def HLSLExport : InheritableAttr {
   let Documentation = [Undocumented];
 }
 
+// HLSL attribute spelled 'indices'; its subject list restricts it to
+// parameter declarations (ParmVar).
+def HLSLIndices : InheritableAttr {
+  let Spellings = [CXX11<"", "indices", 2015>];
+  let Subjects = SubjectList<[ParmVar]>;
+  let Documentation = [Undocumented];
+}
+
+// HLSL attribute spelled 'vertices'; its subject list restricts it to
+// parameter declarations (ParmVar).
+def HLSLVertices : InheritableAttr {
+  let Spellings = [CXX11<"", "vertices", 2015>];
+  let Subjects = SubjectList<[ParmVar]>;
+  let Documentation = [Undocumented];
+}
+
+// HLSL attribute spelled 'primitives'; its subject list restricts it to
+// parameter declarations (ParmVar).
+def HLSLPrimitives : InheritableAttr {
+  let Spellings = [CXX11<"", "primitives", 2015>];
+  let Subjects = SubjectList<[ParmVar]>;
+  let Documentation = [Undocumented];
+}
+
+// HLSL attribute spelled 'payload'; its subject list restricts it to
+// parameter declarations (ParmVar).
+def HLSLPayload : InheritableAttr {
+  let Spellings = [CXX11<"", "payload", 2015>];
+  let Subjects = SubjectList<[ParmVar]>;
+  let Documentation = [Undocumented];
+}
+
 // HLSL Change Ends
 
 // SPIRV Change Starts

+ 10 - 4
tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

@@ -7680,14 +7680,20 @@ def err_hlsl_no_struct_user_defined_type: Error<
    "User defined type intrinsic arg must be struct">;
 def err_hlsl_ray_desc_required: Error<
    "Argument type must be struct RayDesc.">;
-def err_hlsl_missing_maxvertexcount_attr: Error<
-   "GS entry point must have the maxvertexcount attribute">;
-def err_hlsl_missing_patchconstantfunc_attr: Error<
-   "HS entry point must have the patchconstantfunc attribute">;
+def err_hlsl_missing_attr: Error<
+   "%0 entry point must have the %1 attribute">;
 def err_hlsl_missing_inout_attr: Error<
    "stream-output object must be an inout parameter">;
 def err_hlsl_unsupported_string_decl: Error<
    "%select{array|parameter|return value}0 of type string is not supported">;
+def err_hlsl_missing_out_attr: Error<
+   "%0 object must be an out parameter">;
+def err_hlsl_missing_in_attr: Error<
+   "%0 object must be an in parameter">;
+def err_hlsl_load_from_mesh_out_arrays: Error<
+   "output arrays of a mesh shader can not be read from">;
+def err_hlsl_out_indices_array_incorrect_access: Error<
+   "a vector in out indices array must be accessed as a whole">;
 // HLSL Change Ends
 
 // SPIRV Change Starts

+ 4 - 0
tools/clang/include/clang/Basic/TokenKinds.def

@@ -513,6 +513,10 @@ KEYWORD(globallycoherent            , KEYHLSL)
 KEYWORD(interface                   , KEYHLSL)
 KEYWORD(sampler_state               , KEYHLSL)
 KEYWORD(technique                   , KEYHLSL)
+KEYWORD(indices                     , KEYHLSL)
+KEYWORD(vertices                    , KEYHLSL)
+KEYWORD(primitives                  , KEYHLSL)
+KEYWORD(payload                     , KEYHLSL)
 ALIAS("Technique", technique        , KEYHLSL)
 ALIAS("technique10", technique      , KEYHLSL)
 ALIAS("technique11", technique      , KEYHLSL)

+ 1 - 0
tools/clang/include/clang/SPIRV/FeatureManager.h

@@ -44,6 +44,7 @@ enum class Extension {
   GOOGLE_hlsl_functionality1,
   GOOGLE_user_type,
   NV_ray_tracing,
+  NV_mesh_shader,
   Unknown,
 };
 

+ 7 - 0
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -518,6 +518,13 @@ public:
   /// \brief Decorates the given target with NoContraction
   void decorateNoContraction(SpirvInstruction *target, SourceLocation);
 
+  /// \brief Decorates the given target with PerPrimitiveNV
+  void decoratePerPrimitiveNV(SpirvInstruction *target, SourceLocation);
+
+  /// \brief Decorates the given target with PerTaskNV
+  void decoratePerTaskNV(SpirvInstruction *target, uint32_t offset,
+                         SourceLocation);
+
   /// --- Constants ---
   /// Each of these methods can acquire a unique constant from the SpirvContext,
   /// and add the context to the list of constants in the module.

+ 4 - 0
tools/clang/include/clang/SPIRV/SpirvContext.h

@@ -228,6 +228,10 @@ public:
     return curShaderModelKind >= ShaderModelKind::RayGeneration &&
            curShaderModelKind <= ShaderModelKind::Callable;
   }
+  /// \brief Returns true if the current shader model kind is Mesh.
+  bool isMS() const { return curShaderModelKind == ShaderModelKind::Mesh; }
+  /// \brief Returns true if the current shader model kind is Amplification.
+  bool isAS() const {
+    return curShaderModelKind == ShaderModelKind::Amplification;
+  }
 
 private:
   /// \brief The allocator used to create SPIR-V entity objects.

+ 100 - 0
tools/clang/lib/AST/ASTContextHLSL.cpp

@@ -575,6 +575,35 @@ void hlsl::AddStateObjectFlags(ASTContext& context) {
   AddConstUInt(context, curDC, StringRef("STATE_OBJECT_FLAGS_ALLOW_EXTERNAL_DEPENDENCIES_ON_LOCAL_DEFINITIONS"), (unsigned)DXIL::StateObjectFlags::AllowExternalDependenciesOnLocalDefinitions);
 }
 
+/// <summary>Declares the COMMITTED_STATUS typedef and the COMMITTED_* constants.</summary>
+/// <param name="context">AST context whose translation unit receives the declarations.</param>
+void hlsl::AddCommittedStatus(ASTContext& context) {
+  DeclContext *tuDecl = context.getTranslationUnitDecl();
+
+  // Implicit declaration equivalent to: typedef uint COMMITTED_STATUS;
+  IdentifierInfo &typedefId = context.Idents.get(StringRef("COMMITTED_STATUS"), tok::TokenKind::identifier);
+  TypedefDecl *statusTypedef = TypedefDecl::Create(
+      context, tuDecl, NoLoc, NoLoc, &typedefId,
+      context.getTrivialTypeSourceInfo(context.UnsignedIntTy, NoLoc));
+  tuDecl->addDecl(statusTypedef);
+  statusTypedef->setImplicit(true);
+
+  // One 'static const uint' per committed status value, added in this order.
+  const struct {
+    const char *name;
+    DXIL::CommittedStatus value;
+  } statusConstants[] = {
+    {"COMMITTED_NOTHING", DXIL::CommittedStatus::CommittedNothing},
+    {"COMMITTED_TRIANGLE_HIT", DXIL::CommittedStatus::CommittedTriangleHit},
+    {"COMMITTED_PROCEDURAL_PRIMITIVE_HIT", DXIL::CommittedStatus::CommittedProceduralPrimitiveHit},
+  };
+  for (const auto &constant : statusConstants) {
+    AddConstUInt(context, tuDecl, StringRef(constant.name), (unsigned)constant.value);
+  }
+}
+
+/// <summary>Declares the CANDIDATE_TYPE typedef and the CANDIDATE_* constants.</summary>
+/// <param name="context">AST context whose translation unit receives the declarations.</param>
+void hlsl::AddCandidateType(ASTContext& context) {
+  DeclContext *tuDecl = context.getTranslationUnitDecl();
+
+  // Implicit declaration equivalent to: typedef uint CANDIDATE_TYPE;
+  IdentifierInfo &typedefId = context.Idents.get(StringRef("CANDIDATE_TYPE"), tok::TokenKind::identifier);
+  TypedefDecl *candidateTypedef = TypedefDecl::Create(
+      context, tuDecl, NoLoc, NoLoc, &typedefId,
+      context.getTrivialTypeSourceInfo(context.UnsignedIntTy, NoLoc));
+  tuDecl->addDecl(candidateTypedef);
+  candidateTypedef->setImplicit(true);
+
+  // One 'static const uint' per candidate type value, added in this order.
+  const struct {
+    const char *name;
+    DXIL::CandidateType value;
+  } candidateConstants[] = {
+    {"CANDIDATE_NON_OPAQUE_TRIANGLE", DXIL::CandidateType::CandidateNonOpaqueTriangle},
+    {"CANDIDATE_PROCEDURAL_PRIMITIVE", DXIL::CandidateType::CandidateProceduralPrimitive},
+  };
+  for (const auto &constant : candidateConstants) {
+    AddConstUInt(context, tuDecl, StringRef(constant.name), (unsigned)constant.value);
+  }
+}
+
 static
 Expr* IntConstantAsBoolExpr(clang::Sema& sema, uint64_t value)
 {
@@ -959,6 +988,77 @@ CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams(
   return functionDecl;
 }
 
+/// <summary>Creates the RayQuery class template at translation unit scope.</summary>
+/// <param name="context">AST context in which to work.</param>
+/// <param name="typeDecl">Receives the class template declaration.</param>
+/// <param name="recordDecl">Receives the record declaration described by the template.</param>
+void hlsl::AddRayQueryTemplate(
+  ASTContext& context,
+  _Outptr_ ClassTemplateDecl** typeDecl,
+  _Outptr_ CXXRecordDecl** recordDecl
+)
+{
+  DXASSERT_NOMSG(typeDecl != nullptr);
+  DXASSERT_NOMSG(recordDecl != nullptr);
+
+  DeclContext* tuDeclContext = context.getTranslationUnitDecl();
+  QualType uintTy = context.UnsignedIntTy;
+
+  // Shape being built: template<uint flags = 0> class RayQuery { ... }
+  // First the single non-type template parameter 'flags'.
+  IdentifierInfo& flagsParamId = context.Idents.get(StringRef("flags"), tok::TokenKind::identifier);
+  NonTypeTemplateParmDecl* flagsParamDecl = NonTypeTemplateParmDecl::Create(
+    context, tuDeclContext, NoLoc, NoLoc,
+    FirstTemplateDepth, FirstParamPosition, &flagsParamId, uintTy, ParameterPackFalse, nullptr);
+
+  // 'flags' is given a default argument of literal zero.
+  Expr *zeroLiteral = IntegerLiteral::Create(
+    context, llvm::APInt(context.getIntWidth(uintTy), 0), uintTy, NoLoc);
+  flagsParamDecl->setDefaultArgument(zeroLiteral);
+
+  NamedDecl* paramDecls[] = { flagsParamDecl };
+  TemplateParameterList* paramList = TemplateParameterList::Create(
+    context, NoLoc, NoLoc, paramDecls, 1, NoLoc);
+
+  // Create the record and the class template that describes it.
+  IdentifierInfo& rayQueryId = context.Idents.get(StringRef("RayQuery"), tok::TokenKind::identifier);
+  CXXRecordDecl* rayQueryRecord = CXXRecordDecl::Create(
+    context, TagDecl::TagKind::TTK_Class, tuDeclContext, NoLoc, NoLoc, &rayQueryId,
+    nullptr, DelayTypeCreationTrue);
+  ClassTemplateDecl* rayQueryTemplate = ClassTemplateDecl::Create(
+    context, tuDeclContext, NoLoc, DeclarationName(&rayQueryId),
+    paramList, rayQueryRecord, nullptr);
+  rayQueryRecord->setDescribedClassTemplate(rayQueryTemplate);
+  rayQueryRecord->addAttr(FinalAttr::CreateImplicit(context, FinalAttr::Keyword_final));
+
+  // Asking for the injected class name specialization faults in the types
+  // that the rest of the setup requires.
+  QualType injectedTy = rayQueryTemplate->getInjectedClassNameSpecialization();
+  injectedTy = context.getInjectedClassNameType(rayQueryRecord, injectedTy);
+  assert(injectedTy->isDependentType() && "Class template type is not dependent?");
+  rayQueryTemplate->setLexicalDeclContext(tuDeclContext);
+  rayQueryRecord->setLexicalDeclContext(tuDeclContext);
+
+  // The record body consists solely of the handle field.
+  rayQueryRecord->startDefinition();
+  AddHLSLHandleField(context, rayQueryRecord, uintTy);
+  rayQueryRecord->completeDefinition();
+
+  // The template and the record are both registered with the translation unit.
+  tuDeclContext->addDecl(rayQueryTemplate);
+  tuDeclContext->addDecl(rayQueryRecord);
+
+#ifdef DBG
+  // Debug-only sanity check: the 'h' member must be found by name lookup.
+  DeclContext::lookup_result handleLookup = rayQueryRecord->lookup(
+    DeclarationName(&context.Idents.get(StringRef("h"))));
+  DXASSERT(!handleLookup.empty(), "otherwise template object handle cannot be looked up");
+#endif
+
+  *typeDecl = rayQueryTemplate;
+  *recordDecl = rayQueryRecord;
+}
+
 bool hlsl::IsIntrinsicOp(const clang::FunctionDecl *FD) {
   return FD != nullptr && FD->hasAttr<HLSLIntrinsicAttr>();
 }

+ 5 - 0
tools/clang/lib/AST/HlslTypes.cpp

@@ -511,6 +511,11 @@ bool IsHLSLResourceType(clang::QualType type) {
     if (name == "TextureCubeArray" || name == "RWTextureCubeArray")
       return true;
 
+    if (name == "FeedbackTexture2DMinLOD" || name == "FeedbackTexture2DTiled")
+      return true;
+    if (name == "FeedbackTexture2DArrayMinLOD" || name == "FeedbackTexture2DArrayTiled")
+      return true;
+
     if (name == "ByteAddressBuffer" || name == "RWByteAddressBuffer")
       return true;
 

+ 276 - 26
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -545,6 +545,15 @@ StringToTessOutputPrimitive(StringRef primitive) {
   return DXIL::TessellatorOutputPrimitive::Undefined;
 }
 
+// Maps an outputtopology attribute string to a mesh output topology.
+// Any string other than "line" or "triangle" yields Undefined.
+static DXIL::MeshOutputTopology
+StringToMeshOutputTopology(StringRef topology) {
+  DXIL::MeshOutputTopology result = DXIL::MeshOutputTopology::Undefined;
+  if (topology == "line") {
+    result = DXIL::MeshOutputTopology::Line;
+  } else if (topology == "triangle") {
+    result = DXIL::MeshOutputTopology::Triangle;
+  }
+  return result;
+}
+
 static unsigned RoundToAlign(unsigned num, unsigned mod) {
   // round num to next highest mod
   if (mod != 0)
@@ -859,6 +868,27 @@ unsigned CGMSHLSLRuntime::ConstructStructAnnotation(DxilStructAnnotation *annota
   unsigned offset = 0;
   bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
   if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+
+    // If template, save template args
+    if (const ClassTemplateSpecializationDecl *templateSpecializationDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(CXXRD)) {
+      const clang::TemplateArgumentList &args = templateSpecializationDecl->getTemplateInstantiationArgs();
+      for (unsigned i = 0; i < args.size(); ++i) {
+        DxilTemplateArgAnnotation &argAnnotation = annotation->GetTemplateArgAnnotation(i);
+        const clang::TemplateArgument &arg = args[i];
+        switch (arg.getKind()) {
+        case clang::TemplateArgument::ArgKind::Type:
+          argAnnotation.SetType(CGM.getTypes().ConvertType(arg.getAsType()));
+        break;
+        case clang::TemplateArgument::ArgKind::Integral:
+          argAnnotation.SetIntegral(arg.getAsIntegral().getExtValue());
+          break;
+        default:
+          break;
+        }
+      }
+    }
+
     if (CXXRD->getNumBases()) {
       // Add base as field.
       for (const auto &I : CXXRD->bases()) {
@@ -965,6 +995,17 @@ static bool IsElementInputOutputType(QualType Ty) {
   return Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty) || Ty->isEnumeralType();
 }
 
+// Returns the number of template instantiation arguments when RD is a class
+// template specialization, and 0 for any other record declaration.
+static unsigned GetNumTemplateArgsForRecordDecl(const RecordDecl *RD) {
+  const ClassTemplateSpecializationDecl *specDecl =
+      dyn_cast<ClassTemplateSpecializationDecl>(RD);
+  if (specDecl == nullptr)
+    return 0;
+  return specDecl->getTemplateInstantiationArgs().size();
+}
+
 // Return the size for constant buffer of each decl.
 unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
                                             DxilTypeSystem &dxilTypeSys,
@@ -1003,7 +1044,8 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
       unsigned structSize = annotation->GetCBufferSize();
       return structSize;
     }
-    DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST);
+    DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST,
+      GetNumTemplateArgsForRecordDecl(RT->getDecl()));
 
     return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
   } else if (const RecordType *RT = dyn_cast<RecordType>(paramTy)) {
@@ -1015,7 +1057,8 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
       unsigned structSize = annotation->GetCBufferSize();
       return structSize;
     }
-    DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST);
+    DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST,
+      GetNumTemplateArgsForRecordDecl(RT->getDecl()));
 
     return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
   } else if (IsHLSLResourceType(Ty)) {
@@ -1064,6 +1107,10 @@ static DxilResource::Kind KeywordToKind(StringRef keyword) {
     return DxilResource::Kind::Texture2D;
   if (keyword == "Texture2DMS" || keyword == "RWTexture2DMS")
     return DxilResource::Kind::Texture2DMS;
+  if (keyword == "FeedbackTexture2DMinLOD")
+    return DxilResource::Kind::FeedbackTexture2DMinLOD;
+  if (keyword == "FeedbackTexture2DTiled")
+    return DxilResource::Kind::FeedbackTexture2DTiled;
   if (keyword == "Texture3D" || keyword == "RWTexture3D" || keyword == "RasterizerOrderedTexture3D")
     return DxilResource::Kind::Texture3D;
   if (keyword == "TextureCube" || keyword == "RWTextureCube")
@@ -1073,6 +1120,10 @@ static DxilResource::Kind KeywordToKind(StringRef keyword) {
     return DxilResource::Kind::Texture1DArray;
   if (keyword == "Texture2DArray" || keyword == "RWTexture2DArray" || keyword == "RasterizerOrderedTexture2DArray")
     return DxilResource::Kind::Texture2DArray;
+  if (keyword == "FeedbackTexture2DArrayMinLOD")
+    return DxilResource::Kind::FeedbackTexture2DArrayMinLOD;
+  if (keyword == "FeedbackTexture2DArrayTiled")
+    return DxilResource::Kind::FeedbackTexture2DArrayTiled;
   if (keyword == "Texture2DMSArray" || keyword == "RWTexture2DMSArray")
     return DxilResource::Kind::Texture2DMSArray;
   if (keyword == "TextureCubeArray" || keyword == "RWTextureCubeArray")
@@ -1182,6 +1233,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   bool isVS = false;
   bool isPS = false;
   bool isRay = false;
+  bool isMS = false;
+  bool isAS = false;
   if (const HLSLShaderAttr *Attr = FD->getAttr<HLSLShaderAttr>()) {
     // Stage is already validate in HandleDeclAttributeForHLSL.
     // Here just check first letter (or two).
@@ -1233,12 +1286,32 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
       funcProps->shaderKind = DXIL::ShaderKind::Intersection;
       break;
     case 'a':
-      isRay = true;
-      funcProps->shaderKind = DXIL::ShaderKind::AnyHit;
+      switch (Attr->getStage()[1]) {
+      case 'm':
+        isAS = true;
+        funcProps->shaderKind = DXIL::ShaderKind::Amplification;
+        break;
+      case 'n':
+        isRay = true;
+        funcProps->shaderKind = DXIL::ShaderKind::AnyHit;
+        break;
+      default:
+        break;
+      }
       break;
     case 'm':
-      isRay = true;
-      funcProps->shaderKind = DXIL::ShaderKind::Miss;
+      switch (Attr->getStage()[1]) {
+      case 'e':
+        isMS = true;
+        funcProps->shaderKind = DXIL::ShaderKind::Mesh;
+        break;
+      case 'i':
+        isRay = true;
+        funcProps->shaderKind = DXIL::ShaderKind::Miss;
+        break;
+      default:
+        break;
+      }
       break;
     default:
       break;
@@ -1289,6 +1362,12 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   const ShaderModel *SM = m_pHLModule->GetShaderModel();
   if (isEntry) {
     funcProps->shaderKind = SM->GetKind();
+    if (funcProps->shaderKind == DXIL::ShaderKind::Mesh) {
+      isMS = true;
+    }
+    else if (funcProps->shaderKind == DXIL::ShaderKind::Amplification) {
+      isAS = true;
+    }
   }
 
   // Geometry shader.
@@ -1325,16 +1404,26 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
 
   // Computer shader.
   if (const HLSLNumThreadsAttr *Attr = FD->getAttr<HLSLNumThreadsAttr>()) {
-    isCS = true;
-    funcProps->shaderKind = DXIL::ShaderKind::Compute;
+    if (isMS) {
+      funcProps->ShaderProps.MS.numThreads[0] = Attr->getX();
+      funcProps->ShaderProps.MS.numThreads[1] = Attr->getY();
+      funcProps->ShaderProps.MS.numThreads[2] = Attr->getZ();
+    } else if (isAS) {
+      funcProps->ShaderProps.AS.numThreads[0] = Attr->getX();
+      funcProps->ShaderProps.AS.numThreads[1] = Attr->getY();
+      funcProps->ShaderProps.AS.numThreads[2] = Attr->getZ();
+    } else {
+      isCS = true;
+      funcProps->shaderKind = DXIL::ShaderKind::Compute;
 
-    funcProps->ShaderProps.CS.numThreads[0] = Attr->getX();
-    funcProps->ShaderProps.CS.numThreads[1] = Attr->getY();
-    funcProps->ShaderProps.CS.numThreads[2] = Attr->getZ();
+      funcProps->ShaderProps.CS.numThreads[0] = Attr->getX();
+      funcProps->ShaderProps.CS.numThreads[1] = Attr->getY();
+      funcProps->ShaderProps.CS.numThreads[2] = Attr->getZ();
+    }
 
-    if (isEntry && !SM->IsCS()) {
+    if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
       unsigned DiagID = Diags.getCustomDiagID(
-          DiagnosticsEngine::Error, "attribute numthreads only valid for CS.");
+          DiagnosticsEngine::Error, "attribute numthreads only valid for CS/MS/AS.");
       Diags.Report(Attr->getLocation(), DiagID);
       return;
     }
@@ -1398,10 +1487,16 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
       DXIL::TessellatorOutputPrimitive primitive =
           StringToTessOutputPrimitive(Attr->getTopology());
       funcProps->ShaderProps.HS.outputPrimitive = primitive;
-    } else if (isEntry && !SM->IsHS()) {
+    }
+    else if (isMS) {
+      DXIL::MeshOutputTopology topology =
+          StringToMeshOutputTopology(Attr->getTopology());
+      funcProps->ShaderProps.MS.outputTopology = topology;
+    }
+    else if (isEntry && !SM->IsHS() && !SM->IsMS()) {
       unsigned DiagID =
           Diags.getCustomDiagID(DiagnosticsEngine::Warning,
-                                "attribute outputtopology only valid for HS.");
+                                "attribute outputtopology only valid for HS and MS.");
       Diags.Report(Attr->getLocation(), DiagID);
     }
   }
@@ -1478,7 +1573,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     funcProps->shaderKind = DXIL::ShaderKind::Pixel;
   }
 
-  const unsigned profileAttributes = isCS + isHS + isDS + isGS + isVS + isPS + isRay;
+  const unsigned profileAttributes = isCS + isHS + isDS + isGS + isVS + isPS + isRay + isMS + isAS;
 
   // TODO: check this in front-end and report error.
   DXASSERT(profileAttributes < 2, "profile attributes are mutual exclusive");
@@ -1491,6 +1586,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     case ShaderModel::Kind::Geometry:
     case ShaderModel::Kind::Vertex:
     case ShaderModel::Kind::Pixel:
+    case ShaderModel::Kind::Mesh:
+    case ShaderModel::Kind::Amplification:
       DXASSERT(funcProps->shaderKind == SM->GetKind(),
                "attribute profile not match entry function profile");
       break;
@@ -1561,6 +1658,10 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     funcProps->ShaderProps.Ray.attributeSizeInBytes = 0;
   }
 
+  bool hasOutIndices = false;
+  bool hasOutVertices = false;
+  bool hasOutPrimitives = false;
+  bool hasInPayload = false;
   for (; ArgNo < F->arg_size(); ++ArgNo, ++ParmIdx) {
     DxilParameterAnnotation &paramAnnotation =
         FuncAnnotation->GetParameterAnnotation(ArgNo);
@@ -1591,6 +1692,155 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     if (parmDecl->hasAttr<HLSLOutAttr>() && parmDecl->hasAttr<HLSLInAttr>())
       dxilInputQ = DxilParamInputQual::Inout;
 
+    if (parmDecl->hasAttr<HLSLOutAttr>() && parmDecl->hasAttr<HLSLIndicesAttr>()) {
+      if (hasOutIndices) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "multiple out indices parameters not allowed");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(fieldTy.getCanonicalType());
+      if (CAT == nullptr) {
+        unsigned DiagID = Diags.getCustomDiagID(
+          DiagnosticsEngine::Error,
+          "indices output is not an constant-length array");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      unsigned count = CAT->getSize().getZExtValue();
+      if (count > DXIL::kMaxMSOutputPrimitiveCount) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "max primitive count should not exceed %0");
+        Diags.Report(parmDecl->getLocation(), DiagID) << DXIL::kMaxMSOutputPrimitiveCount;
+        continue;
+      }
+      if (funcProps->ShaderProps.MS.maxPrimitiveCount != 0 &&
+        funcProps->ShaderProps.MS.maxPrimitiveCount != count) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "max primitive count mismatch");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      // Get element type.
+      QualType arrayEleTy = CAT->getElementType();
+
+      if (hlsl::IsHLSLVecType(arrayEleTy)) {
+        QualType vecEltTy = hlsl::GetHLSLVecElementType(arrayEleTy);
+        if (!vecEltTy->isUnsignedIntegerType() || CGM.getContext().getTypeSize(vecEltTy) != 32) {
+          unsigned DiagID = Diags.getCustomDiagID(
+              DiagnosticsEngine::Error,
+              "the element of out_indices array must be uint2 for line output or uint3 for triangle output");
+          Diags.Report(parmDecl->getLocation(), DiagID);
+          continue;
+        }
+        unsigned vecEltCount = hlsl::GetHLSLVecSize(arrayEleTy);
+        if (funcProps->ShaderProps.MS.outputTopology == DXIL::MeshOutputTopology::Line && vecEltCount != 2) {
+          unsigned DiagID = Diags.getCustomDiagID(
+              DiagnosticsEngine::Error,
+              "the element of out_indices array in a mesh shader whose output topology is line must be uint2");
+          Diags.Report(parmDecl->getLocation(), DiagID);
+          continue;
+        }
+        if (funcProps->ShaderProps.MS.outputTopology == DXIL::MeshOutputTopology::Triangle && vecEltCount != 3) {
+          unsigned DiagID = Diags.getCustomDiagID(
+              DiagnosticsEngine::Error,
+              "the element of out_indices array in a mesh shader whose output topology is triangle must be uint3");
+          Diags.Report(parmDecl->getLocation(), DiagID);
+          continue;
+        }
+      } else {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "the element of out_indices array must be uint2 for line output or uint3 for triangle output");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+
+      dxilInputQ = DxilParamInputQual::OutIndices;
+      funcProps->ShaderProps.MS.maxPrimitiveCount = count;
+      hasOutIndices = true;
+    }
+    if (parmDecl->hasAttr<HLSLOutAttr>() && parmDecl->hasAttr<HLSLVerticesAttr>()) {
+      if (hasOutVertices) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "multiple out vertices parameters not allowed");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(fieldTy.getCanonicalType());
+      if (CAT == nullptr) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "vertices output is not an constant-length array");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      unsigned count = CAT->getSize().getZExtValue();
+      if (count > DXIL::kMaxMSOutputVertexCount) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "max vertex count should not exceed %0");
+        Diags.Report(parmDecl->getLocation(), DiagID) << DXIL::kMaxMSOutputVertexCount;
+        continue;
+      }
+
+      dxilInputQ = DxilParamInputQual::OutVertices;
+      funcProps->ShaderProps.MS.maxVertexCount = count;
+      hasOutVertices = true;
+    }
+    if (parmDecl->hasAttr<HLSLOutAttr>() && parmDecl->hasAttr<HLSLPrimitivesAttr>()) {
+      if (hasOutPrimitives) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "multiple out primitives parameters not allowed");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(fieldTy.getCanonicalType());
+      if (CAT == nullptr) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "primitives output is not an constant-length array");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      unsigned count = CAT->getSize().getZExtValue();
+      if (count > DXIL::kMaxMSOutputPrimitiveCount) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "max primitive count should not exceed %0");
+        Diags.Report(parmDecl->getLocation(), DiagID) << DXIL::kMaxMSOutputPrimitiveCount;
+        continue;
+      }
+      if (funcProps->ShaderProps.MS.maxPrimitiveCount != 0 &&
+        funcProps->ShaderProps.MS.maxPrimitiveCount != count) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "max primitive count mismatch");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+
+      dxilInputQ = DxilParamInputQual::OutPrimitives;
+      funcProps->ShaderProps.MS.maxPrimitiveCount = count;
+      hasOutPrimitives = true;
+    }
+    if (parmDecl->hasAttr<HLSLInAttr>() && parmDecl->hasAttr<HLSLPayloadAttr>()) {
+      if (hasInPayload) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "multiple in payload parameters not allowed");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      dxilInputQ = DxilParamInputQual::InPayload;
+      hasInPayload = true;
+    }
+
     DXIL::InputPrimitive inputPrimitive = DXIL::InputPrimitive::Undefined;
 
     if (IsHLSLOutputPatchType(parmDecl->getType())) {
@@ -2286,6 +2536,10 @@ static DxilResourceBase::Class KeywordToClass(const std::string &keyword) {
   isUAV |= keyword == "RasterizerOrderedTexture2D";
   isUAV |= keyword == "RasterizerOrderedTexture2DArray";
   isUAV |= keyword == "RasterizerOrderedTexture3D";
+  isUAV |= keyword == "FeedbackTexture2DMinLOD";
+  isUAV |= keyword == "FeedbackTexture2DTiled";
+  isUAV |= keyword == "FeedbackTexture2DArrayMinLOD";
+  isUAV |= keyword == "FeedbackTexture2DArrayTiled";
   if (isUAV)
     return DxilResourceBase::Class::UAV;
 
@@ -2616,6 +2870,8 @@ bool CGMSHLSLRuntime::SetUAVSRV(SourceLocation loc,
   RecordDecl *RD = QualTy->getAs<RecordType>()->getDecl();
 
   hlsl::DxilResource::Kind kind = KeywordToKind(RD->getName());
+  DXASSERT_NOMSG(kind != hlsl::DxilResource::Kind::Invalid);
+
   hlslRes->SetKind(kind);
   
   QualType resultTy = hlsl::GetHLSLResourceResultType(QualTy);
@@ -3633,10 +3889,8 @@ static void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F,
     llvm::Type *Ty = paramTyList[i];
     if (Ty->isPointerTy()) {
       Ty = Ty->getPointerElementType();
-      if (dxilutil::IsHLSLObjectType(Ty) &&
-          // StreamOutput don't need handle.
-          !HLModule::IsStreamOutputType(Ty)) {
-        // Use handle type for object type.
+      if (dxilutil::IsHLSLResourceType(Ty)) {
+        // Use handle type for resource type.
         // This will make sure temp object variable only used by createHandle.
         paramTyList[i] = HandleTy;
       }
@@ -3694,7 +3948,7 @@ static void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F,
     gep_type_iterator GEPIt = gep_type_begin(objGEP), E = gep_type_end(objGEP);
     llvm::Type *resTy = nullptr;
     while (GEPIt != E) {
-      if (dxilutil::IsHLSLObjectType(*GEPIt)) {
+      if (dxilutil::IsHLSLResourceType(*GEPIt)) {
         resTy = *GEPIt;
         break;
       }
@@ -3777,9 +4031,7 @@ static void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F,
       llvm::Type *Ty = arg->getType();
       if (Ty->isPointerTy()) {
         Ty = Ty->getPointerElementType();
-        if (dxilutil::IsHLSLObjectType(Ty) &&
-          // StreamOutput don't need handle.
-          !HLModule::IsStreamOutputType(Ty)) {
+        if (dxilutil::IsHLSLResourceType(Ty)) {
           // Use object type directly, not by pointer.
           // This will make sure temp object variable only used by ld/st.
           if (GEPOperator *argGEP = dyn_cast<GEPOperator>(arg)) {
@@ -4791,8 +5043,6 @@ static void CreateWriteEnabledStaticGlobals(llvm::Module *M,
   }
 }
 
-
-
 void CGMSHLSLRuntime::FinishCodeGen() {
   // Library don't have entry.
   if (!m_bIsLib) {

+ 21 - 0
tools/clang/lib/Parse/ParseDecl.cpp

@@ -763,6 +763,10 @@ void Parser::ParseGNUAttributeArgs(IdentifierInfo *AttrName,
     //case AttributeList::AT_HLSLLineAdj:
     //case AttributeList::AT_HLSLTriangle:
     //case AttributeList::AT_HLSLTriangleAdj:
+    //case AttributeList::AT_HLSLIndices:
+    //case AttributeList::AT_HLSLVertices:
+    //case AttributeList::AT_HLSLPrimitives:
+    //case AttributeList::AT_HLSLPayload:
       goto GenericAttributeParse;
     default:
       Diag(AttrNameLoc, diag::err_hlsl_unsupported_construct) << AttrName;
@@ -3786,9 +3790,15 @@ HLSLReservedKeyword:
     case tok::kw_sample:
     case tok::kw_globallycoherent:
     case tok::kw_center:
+    case tok::kw_indices:
+    case tok::kw_vertices:
+    case tok::kw_primitives:
+    case tok::kw_payload:
       // Back-compat: 'precise', 'globallycoherent', 'center' and 'sample' are keywords when used as an interpolation
       // modifiers, but in FXC they can also be used an identifiers. If the decl type has already been specified
       // we need to update the token to be handled as an identifier.
+      // Similarly 'indices', 'vertices', 'primitives' and 'payload' are keywords
+      // when used as a type qualifier in mesh shader, but may still be used as a variable name.
       if (getLangOpts().HLSL) {
         if (DS.getTypeSpecType() != DeclSpec::TST_unspecified) {
           Tok.setKind(tok::identifier);
@@ -5236,6 +5246,10 @@ bool Parser::isDeclarationSpecifier(bool DisambiguatingWithExpression) {
   case tok::kw_triangle:
   case tok::kw_triangleadj:
   case tok::kw_export:
+  case tok::kw_indices:
+  case tok::kw_vertices:
+  case tok::kw_primitives:
+  case tok::kw_payload:
     return true;
   // HLSL Change Ends
 
@@ -6006,6 +6020,9 @@ void Parser::ParseDirectDeclarator(Declarator &D) {
     // FXC they can also be used an identifiers. If the next token is a
     // punctuator, then we are using them as identifers. Need to change
     // the token type to tok::identifier and fall through to the next case.
+    // Similarly 'indices', 'vertices', 'primitives' and 'payload' are keywords
+    // when used as a type qualifier in mesh shader, but may still be used as a
+    // variable name.
     // E.g., <type> left, center, right;
     if (getLangOpts().HLSL) {
       switch (Tok.getKind()) {
@@ -6013,6 +6030,10 @@ void Parser::ParseDirectDeclarator(Declarator &D) {
       case tok::kw_globallycoherent:
       case tok::kw_precise:
       case tok::kw_sample:
+      case tok::kw_indices:
+      case tok::kw_vertices:
+      case tok::kw_primitives:
+      case tok::kw_payload:
         if (tok::isPunctuator(NextToken().getKind()))
           Tok.setKind(tok::identifier);
         break;

+ 10 - 0
tools/clang/lib/Parse/ParseExpr.cpp

@@ -794,9 +794,15 @@ HLSLReservedKeyword:
   case tok::kw_sample:
   case tok::kw_globallycoherent:
   case tok::kw_center:
+  case tok::kw_indices:
+  case tok::kw_vertices:
+  case tok::kw_primitives:
+  case tok::kw_payload:
     // Back-compat: 'precise', 'globallycoherent', 'center' and 'sample' are keywords when used as an interpolation 
     // modifiers, but in FXC they can also be used an identifiers. No interpolation modifiers are expected here
     // so we need to change the token type to tok::identifier and fall through to the next case.
+    // Similarly 'indices', 'vertices', 'primitives' and 'payload' are keywords when used
+    // as a type qualifier in mesh shader, but may still be used as a variable name.
     Tok.setKind(tok::identifier);
     __fallthrough;
     // HLSL Change Ends
@@ -1722,6 +1728,10 @@ Parser::ParsePostfixExpressionSuffix(ExprResult LHS) {
         case tok::kw_globallycoherent:
         case tok::kw_precise:
         case tok::kw_sample:
+        case tok::kw_indices:
+        case tok::kw_vertices:
+        case tok::kw_primitives:
+        case tok::kw_payload:
           Tok.setKind(tok::identifier);
           Tok.setIdentifierInfo(PP.getIdentifierInfo(getKeywordSpelling(tk)));
           break;

+ 5 - 1
tools/clang/lib/Parse/ParseStmt.cpp

@@ -179,7 +179,11 @@ Retry:
   case tok::kw_precise:
   case tok::kw_sample:
   case tok::kw_globallycoherent:
-  case tok::kw_center: {
+  case tok::kw_center:
+  case tok::kw_indices:
+  case tok::kw_vertices:
+  case tok::kw_primitives:
+  case tok::kw_payload: {
     // FXC compatiblity: these are keywords when used as modifiers, but in
     // FXC they can also be used an identifiers. If the next token is a
     // punctuator, then we are using them as identifers. Need to change

+ 4 - 0
tools/clang/lib/Parse/ParseTentative.cpp

@@ -1275,6 +1275,10 @@ Parser::isCXXDeclarationSpecifier(Parser::TPResult BracedCastResult,
   case tok::kw_precise:
   case tok::kw_center:
   case tok::kw_globallycoherent:
+  case tok::kw_indices:
+  case tok::kw_vertices:
+  case tok::kw_primitives:
+  case tok::kw_payload:
     // FXC compatiblity: these are keywords when used as modifiers, but in
     // FXC they can also be used an identifiers. If the next token is a
     // punctuator, then we are using them as identifers. Need to change

+ 14 - 6
tools/clang/lib/SPIRV/CapabilityVisitor.cpp

@@ -273,8 +273,9 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
       break;
     }
     case spv::BuiltIn::PrimitiveId: {
-      // PrimitiveID can be used as PSIn
-      if (shaderModel == spv::ExecutionModel::Fragment)
+      // PrimitiveID can be used as PSIn or MSPOut.
+      if (shaderModel == spv::ExecutionModel::Fragment ||
+          shaderModel == spv::ExecutionModel::MeshNV)
         addCapability(spv::Capability::Geometry);
       break;
     }
@@ -285,8 +286,9 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
         addExtension(Extension::EXT_shader_viewport_index_layer,
                      "SV_RenderTargetArrayIndex", loc);
         addCapability(spv::Capability::ShaderViewportIndexLayerEXT);
-      } else if (shaderModel == spv::ExecutionModel::Fragment) {
-        // SV_RenderTargetArrayIndex can be used as PSIn.
+      } else if (shaderModel == spv::ExecutionModel::Fragment ||
+                 shaderModel == spv::ExecutionModel::MeshNV) {
+        // SV_RenderTargetArrayIndex can be used as PSIn or MSPOut.
         addCapability(spv::Capability::Geometry);
       }
       break;
@@ -299,8 +301,9 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
                      "SV_ViewPortArrayIndex", loc);
         addCapability(spv::Capability::ShaderViewportIndexLayerEXT);
       } else if (shaderModel == spv::ExecutionModel::Fragment ||
-                 shaderModel == spv::ExecutionModel::Geometry) {
-        // SV_ViewportArrayIndex can be used as PSIn.
+                 shaderModel == spv::ExecutionModel::Geometry ||
+                 shaderModel == spv::ExecutionModel::MeshNV) {
+        // SV_ViewportArrayIndex can be used as PSIn or GSOut or MSPOut.
         addCapability(spv::Capability::MultiViewport);
       }
       break;
@@ -503,6 +506,11 @@ bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
     addCapability(spv::Capability::RayTracingNV);
     addExtension(Extension::NV_ray_tracing, "SPV_NV_ray_tracing", {});
     break;
+  case spv::ExecutionModel::MeshNV:
+  case spv::ExecutionModel::TaskNV:
+    addCapability(spv::Capability::MeshShadingNV);
+    addExtension(Extension::NV_mesh_shader, "SPV_NV_mesh_shader", {});
+    break;
   default:
     llvm_unreachable("found unknown shader model");
     break;

+ 243 - 49
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -21,6 +21,7 @@
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Casting.h"
 
+#include "AlignmentSizeCalculator.h"
 #include "SpirvEmitter.h"
 
 namespace clang {
@@ -295,6 +296,15 @@ hlsl::DxilParamInputQual deduceParamQual(const DeclaratorDecl *decl,
   if (hasGSPrimitiveTypeQualifier(decl))
     return hlsl::DxilParamInputQual::InputPrimitive;
 
+  if (decl->hasAttr<HLSLIndicesAttr>())
+    return hlsl::DxilParamInputQual::OutIndices;
+  if (decl->hasAttr<HLSLVerticesAttr>())
+    return hlsl::DxilParamInputQual::OutVertices;
+  if (decl->hasAttr<HLSLPrimitivesAttr>())
+    return hlsl::DxilParamInputQual::OutPrimitives;
+  if (decl->hasAttr<HLSLPayloadAttr>())
+    return hlsl::DxilParamInputQual::InPayload;
+
   return asInput ? hlsl::DxilParamInputQual::In : hlsl::DxilParamInputQual::Out;
 }
 
@@ -432,12 +442,35 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
                                               SpirvInstruction *storedValue,
                                               bool forPCF) {
   QualType type = getTypeOrFnRetType(decl);
+  uint32_t arraySize = 0;
 
   // Output stream types (PointStream, LineStream, TriangleStream) are
   // translated as their underlying struct types.
   if (hlsl::IsHLSLStreamOutputType(type))
     type = hlsl::GetHLSLResourceResultType(type);
 
+  if (decl->hasAttr<HLSLIndicesAttr>() || decl->hasAttr<HLSLVerticesAttr>() ||
+      decl->hasAttr<HLSLPrimitivesAttr>()) {
+    const auto *typeDecl = astContext.getAsConstantArrayType(type);
+    type = typeDecl->getElementType();
+    arraySize = static_cast<uint32_t>(typeDecl->getSize().getZExtValue());
+    if (decl->hasAttr<HLSLIndicesAttr>()) {
+      // create SPIR-V builtin array PrimitiveIndicesNV of type
+      // "uint [MaxPrimitiveCount * verticesPerPrim]"
+      uint32_t verticesPerPrim = 1;
+      if (!isVectorType(type, nullptr, &verticesPerPrim)) {
+        assert(isScalarType(type));
+      }
+      arraySize = arraySize * verticesPerPrim;
+      QualType arrayType = astContext.getConstantArrayType(
+          astContext.UnsignedIntTy, llvm::APInt(32, arraySize),
+          clang::ArrayType::Normal, 0);
+      stageVarInstructions[cast<DeclaratorDecl>(decl)] = getBuiltinVar(
+          spv::BuiltIn::PrimitiveIndicesNV, arrayType, decl->getLocation());
+      return true;
+    }
+  }
+
   const auto *sigPoint = deduceSigPoint(
       decl, /*asInput=*/false, spvContext.getCurrentShaderModelKind(), forPCF);
 
@@ -453,11 +486,12 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
   // Write back of stage output variables in GS is manually controlled by
   // .Append() intrinsic method, implemented in writeBackOutputStream(). So
   // ignoreValue should be set to true for GS.
-  const bool noWriteBack = storedValue == nullptr || spvContext.isGS();
+  const bool noWriteBack =
+      storedValue == nullptr || spvContext.isGS() || spvContext.isMS();
 
-  return createStageVars(sigPoint, decl, /*asInput=*/false, type,
-                         /*arraySize=*/0, "out.var", llvm::None, &storedValue,
-                         noWriteBack, &inheritSemantic);
+  return createStageVars(sigPoint, decl, /*asInput=*/false, type, arraySize,
+                         "out.var", llvm::None, &storedValue, noWriteBack,
+                         &inheritSemantic);
 }
 
 bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
@@ -505,9 +539,15 @@ bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
 
   SemanticInfo inheritSemantic = {};
 
-  return createStageVars(sigPoint, paramDecl, /*asInput=*/true, type, arraySize,
-                         "in.var", llvm::None, loadedValue,
-                         /*noWriteBack=*/false, &inheritSemantic);
+  if (paramDecl->hasAttr<HLSLPayloadAttr>()) {
+    spv::StorageClass sc = getStorageClassForSigPoint(sigPoint);
+    return createPayloadStageVars(sigPoint, sc, paramDecl, /*asInput=*/true,
+                                  type, "in.var", loadedValue);
+  } else {
+    return createStageVars(sigPoint, paramDecl, /*asInput=*/true, type,
+                           arraySize, "in.var", llvm::None, loadedValue,
+                           /*noWriteBack=*/false, &inheritSemantic);
+  }
 }
 
 const DeclResultIdMapper::DeclSpirvInfo *
@@ -1193,9 +1233,10 @@ bool DeclResultIdMapper::checkSemanticDuplication(bool forInput) {
     auto s = var.getSemanticStr();
 
     if (s.empty()) {
-      // We translate WaveGetLaneCount() and WaveGetLaneIndex() into builtin
-      // variables. Those variables are inserted into the normal stage IO
-      // processing pipeline, but with the semantics as empty strings.
+      // We translate WaveGetLaneCount(), WaveGetLaneIndex() and 'payload' param
+      // block declaration into builtin variables. Those variables are inserted
+      // into the normal stage IO processing pipeline, but with the semantics as
+      // empty strings.
       assert(var.isSpirvBuitin());
       continue;
     }
@@ -1671,17 +1712,31 @@ bool DeclResultIdMapper::createStageVars(
       return false;
 
     const auto semanticKind = semanticToUse->getKind();
+    const auto sigPointKind = sigPoint->GetKind();
 
     // Error out when the given semantic is invalid in this shader model
-    if (hlsl::SigPoint::GetInterpretation(semanticKind, sigPoint->GetKind(),
+    if (hlsl::SigPoint::GetInterpretation(semanticKind, sigPointKind,
                                           spvContext.getMajorVersion(),
                                           spvContext.getMinorVersion()) ==
         hlsl::DXIL::SemanticInterpretationKind::NA) {
-      emitError("invalid usage of semantic '%0' in shader profile %1", loc)
-          << semanticToUse->str
-          << hlsl::ShaderModel::GetKindName(
-                 spvContext.getCurrentShaderModelKind());
-      return false;
+      // Special handle MSIn/ASIn allowing VK-only builtin "DrawIndex".
+      switch (sigPointKind) {
+      case hlsl::SigPoint::Kind::MSIn:
+      case hlsl::SigPoint::Kind::ASIn:
+        if (const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>()) {
+          const llvm::StringRef builtin = builtinAttr->getBuiltIn();
+          if (builtin == "DrawIndex") {
+            break;
+          }
+        }
+        // fall through
+      default:
+        emitError("invalid usage of semantic '%0' in shader profile %1", loc)
+            << semanticToUse->str
+            << hlsl::ShaderModel::GetKindName(
+                   spvContext.getCurrentShaderModelKind());
+        return false;
+      }
     }
 
     if (!validateVKBuiltins(decl, sigPoint))
@@ -1711,9 +1766,9 @@ bool DeclResultIdMapper::createStageVars(
     // * SV_ShadingRate is a uint value, but the builtin it corresponds to is a
     //   int2.
 
-    if (glPerVertex.tryToAccess(sigPoint->GetKind(), semanticKind,
+    if (glPerVertex.tryToAccess(sigPointKind, semanticKind,
                                 semanticToUse->index, invocationId, value,
-                                noWriteBack, loc))
+                                noWriteBack, /*vecComponent=*/nullptr, loc))
       return true;
 
     switch (semanticKind) {
@@ -1757,7 +1812,7 @@ bool DeclResultIdMapper::createStageVars(
 
     // Boolean stage I/O variables must be represented as unsigned integers.
     // Boolean built-in variables are represented as bool.
-    if (isBooleanStageIOVar(decl, type, semanticKind, sigPoint->GetKind())) {
+    if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) {
       evalType = getUintTypeWithSourceComponents(astContext, type);
     }
 
@@ -1796,8 +1851,15 @@ bool DeclResultIdMapper::createStageVars(
 
     // TODO: the following may not be correct?
     if (sigPoint->GetSignatureKind() ==
-        hlsl::DXIL::SignatureKind::PatchConstant)
-      spvBuilder.decoratePatch(varInstr, varInstr->getSourceLocation());
+        hlsl::DXIL::SignatureKind::PatchConstOrPrim) {
+      if (sigPointKind == hlsl::SigPoint::Kind::MSPOut) {
+        // Decorate with PerPrimitiveNV for per-primitive out variables.
+        spvBuilder.decoratePerPrimitiveNV(varInstr,
+                                          varInstr->getSourceLocation());
+      } else {
+        spvBuilder.decoratePatch(varInstr, varInstr->getSourceLocation());
+      }
+    }
 
     // Decorate with interpolation modes for pixel shader input variables
     if (spvContext.isPS() && sigPoint->IsInput() &&
@@ -1947,7 +2009,7 @@ bool DeclResultIdMapper::createStageVars(
 
       // Since boolean stage input variables are represented as unsigned
       // integers, after loading them, we should cast them to boolean.
-      if (isBooleanStageIOVar(decl, type, semanticKind, sigPoint->GetKind())) {
+      if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) {
         *value =
             theEmitter.castToType(*value, evalType, type, thisSemantic.loc);
       }
@@ -2026,8 +2088,7 @@ bool DeclResultIdMapper::createStageVars(
       }
       // Since boolean output stage variables are represented as unsigned
       // integers, we must cast the value to uint before storing.
-      else if (isBooleanStageIOVar(decl, type, semanticKind,
-                                   sigPoint->GetKind())) {
+      else if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) {
         *value =
             theEmitter.castToType(*value, type, evalType, thisSemantic.loc);
         spvBuilder.createStore(ptr, *value, thisSemantic.loc);
@@ -2153,6 +2214,7 @@ bool DeclResultIdMapper::createStageVars(
     //       but we only write to the struct at the InvocationID index
     // * DS: output is a single struct, without extra arrayness
     // * GS: output is controlled by OpEmitVertex, one vertex per time
+    // * MS: output is an array of structs, with extra arrayness
     //
     // The interesting shader stage is HS. We need the InvocationID to write
     // out the value to the correct array element.
@@ -2174,6 +2236,96 @@ bool DeclResultIdMapper::createStageVars(
   return true;
 }
 
+bool DeclResultIdMapper::createPayloadStageVars(
+    const hlsl::SigPoint *sigPoint, spv::StorageClass sc, const NamedDecl *decl,
+    bool asInput, QualType type, const llvm::StringRef namePrefix,
+    SpirvInstruction **value, uint32_t payloadMemOffset) {
+  assert(spvContext.isMS() || spvContext.isAS());
+  assert(value);
+
+  if (type->isVoidType()) {
+    // No stage variables will be created for void type.
+    return true;
+  }
+
+  const auto loc = decl->getLocation();
+  if (!type->isStructureType()) {
+    StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
+                      getLocationCount(astContext, type));
+    const auto name = namePrefix.str() + "." + decl->getNameAsString();
+    SpirvVariable *varInstr =
+        spvBuilder.addStageIOVar(type, sc, name, /*isPrecise=*/false, loc);
+
+    if (!varInstr)
+      return false;
+
+    // Even though these are user defined IO stage variables, set them as SPIR-V
+    // builtins in order to bypass any semantic string checks and location
+    // assignment.
+    stageVar.setIsSpirvBuiltin();
+    stageVar.setSpirvInstr(varInstr);
+    stageVars.push_back(stageVar);
+
+    // Decorate with PerTaskNV for mesh/amplification shader payload variables.
+    spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset,
+                                 varInstr->getSourceLocation());
+
+    if (asInput) {
+      *value = spvBuilder.createLoad(type, varInstr, loc);
+    } else {
+      spvBuilder.createStore(varInstr, *value, loc);
+    }
+    return true;
+  }
+
+  // This decl translates into multiple stage input/output payload variables
+  // and we need to load/store these individual member variables.
+  const auto *structDecl = type->getAs<RecordType>()->getDecl();
+  llvm::SmallVector<SpirvInstruction *, 4> subValues;
+  AlignmentSizeCalculator alignmentCalc(astContext, spirvOptions);
+  uint32_t nextMemberOffset = 0;
+
+  for (const auto *field : structDecl->fields()) {
+    const auto fieldType = field->getType();
+    SpirvInstruction *subValue = nullptr;
+    uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
+
+    // The next available offset after laying out the previous members.
+    std::tie(memberAlignment, memberSize) = alignmentCalc.getAlignmentAndSize(
+        field->getType(), spirvOptions.ampPayloadLayoutRule,
+        /*isRowMajor*/ llvm::None, &stride);
+    alignmentCalc.alignUsingHLSLRelaxedLayout(
+        field->getType(), memberSize, memberAlignment, &nextMemberOffset);
+
+    // The vk::offset attribute takes precedence over all.
+    if (field->getAttr<VKOffsetAttr>()) {
+      nextMemberOffset = field->getAttr<VKOffsetAttr>()->getOffset();
+    }
+
+    // Each payload member must have an Offset Decoration.
+    payloadMemOffset = nextMemberOffset;
+    nextMemberOffset += memberSize;
+
+    if (!asInput) {
+      subValue = spvBuilder.createCompositeExtract(
+          fieldType, *value, {getNumBaseClasses(type) + field->getFieldIndex()},
+          loc);
+    }
+
+    if (!createPayloadStageVars(sigPoint, sc, field, asInput, field->getType(),
+                                namePrefix, &subValue, payloadMemOffset))
+      return false;
+
+    if (asInput) {
+      subValues.push_back(subValue);
+    }
+  }
+  if (asInput) {
+    *value = spvBuilder.createCompositeConstruct(type, subValues, loc);
+  }
+  return true;
+}
+
 bool DeclResultIdMapper::writeBackOutputStream(const NamedDecl *decl,
                                                QualType type,
                                                SpirvInstruction *value) {
@@ -2191,11 +2343,11 @@ bool DeclResultIdMapper::writeBackOutputStream(const NamedDecl *decl,
     // Found semantic attached directly to this Decl. Write the value for this
     // Decl to the corresponding stage output variable.
 
-    // Handle SV_Position, SV_ClipDistance, and SV_CullDistance
-    if (glPerVertex.tryToAccess(hlsl::DXIL::SigPointKind::GSOut,
-                                semanticInfo.semantic->GetKind(),
-                                semanticInfo.index, llvm::None, &value,
-                                /*noWriteBack=*/false, loc))
+    // Handle SV_ClipDistance, and SV_CullDistance
+    if (glPerVertex.tryToAccess(
+            hlsl::DXIL::SigPointKind::GSOut, semanticInfo.semantic->GetKind(),
+            semanticInfo.index, llvm::None, &value,
+            /*noWriteBack=*/false, /*vecComponent=*/nullptr, loc))
       return true;
 
     // Query the <result-id> for the stage output variable generated out
@@ -2335,6 +2487,8 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
   if (builtInVar != builtinToVarMap.end()) {
     return builtInVar->second;
   }
+  spv::StorageClass sc = spv::StorageClass::Max;
+  // Valid builtins supported
   switch (builtIn) {
   case spv::BuiltIn::SubgroupSize:
   case spv::BuiltIn::SubgroupLocalInvocationId:
@@ -2353,7 +2507,12 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
   case spv::BuiltIn::WorldToObjectNV:
   case spv::BuiltIn::LaunchIdNV:
   case spv::BuiltIn::LaunchSizeNV:
-    // Valid builtins supported
+    sc = spv::StorageClass::Input;
+    break;
+  case spv::BuiltIn::PrimitiveCountNV:
+  case spv::BuiltIn::PrimitiveIndicesNV:
+  case spv::BuiltIn::TaskCountNV:
+    sc = spv::StorageClass::Output;
     break;
   default:
     assert(false && "unsupported SPIR-V builtin");
@@ -2361,8 +2520,8 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
   }
 
   // Create a dummy StageVar for this builtin variable
-  auto var = spvBuilder.addStageBuiltinVar(type, spv::StorageClass::Input,
-                                           builtIn, /*isPrecise*/ false, loc);
+  auto var = spvBuilder.addStageBuiltinVar(type, sc, builtIn,
+                                           /*isPrecise*/ false, loc);
 
   const hlsl::SigPoint *sigPoint =
       hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
@@ -2407,6 +2566,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
             .Case("BaseInstance", BuiltIn::BaseInstance)
             .Case("DrawIndex", BuiltIn::DrawIndex)
             .Case("DeviceIndex", BuiltIn::DeviceIndex)
+            .Case("ViewportMaskNV", BuiltIn::ViewportMaskNV)
             .Default(BuiltIn::Max);
 
     assert(spvBuiltIn != BuiltIn::Max); // The frontend should guarantee this.
@@ -2419,9 +2579,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   // each semantic.
   switch (semanticKind) {
   // According to DXIL spec, the Position SV can be used by all SigPoints
-  // other than PCIn, HSIn, GSIn, PSOut, CSIn.
+  // other than PCIn, HSIn, GSIn, PSOut, CSIn, MSIn, MSPOut, ASIn.
   // According to Vulkan spec, the Position BuiltIn can only be used
-  // by VSOut, HS/DS/GS In/Out.
+  // by VSOut, HS/DS/GS In/Out, MSOut.
   case hlsl::Semantic::Kind::Position: {
     switch (sigPointKind) {
     case hlsl::SigPoint::Kind::VSIn:
@@ -2435,6 +2595,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     case hlsl::SigPoint::Kind::DSOut:
     case hlsl::SigPoint::Kind::GSVIn:
     case hlsl::SigPoint::Kind::GSOut:
+    case hlsl::SigPoint::Kind::MSOut:
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Position,
                                            isPrecise, srcLoc);
@@ -2497,9 +2658,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
                                          isPrecise, srcLoc);
   }
   // According to DXIL spec, the ClipDistance/CullDistance SV can be used by all
-  // SigPoints other than PCIn, HSIn, GSIn, PSOut, CSIn.
-  // According to Vulkan spec, the ClipDistance/CullDistance BuiltIn can only be
-  // used by VSOut, HS/DS/GS In/Out.
+  // SigPoints other than PCIn, HSIn, GSIn, PSOut, CSIn, MSIn, MSPOut, ASIn.
+  // According to Vulkan spec, the ClipDistance/CullDistance
+  // BuiltIn can only be used by VSOut, HS/DS/GS In/Out, MSOut.
   case hlsl::Semantic::Kind::ClipDistance:
   case hlsl::Semantic::Kind::CullDistance: {
     switch (sigPointKind) {
@@ -2515,6 +2676,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     case hlsl::SigPoint::Kind::GSVIn:
     case hlsl::SigPoint::Kind::GSOut:
     case hlsl::SigPoint::Kind::PSIn:
+    case hlsl::SigPoint::Kind::MSOut:
       llvm_unreachable("should be handled in gl_PerVertex struct");
     default:
       llvm_unreachable(
@@ -2584,9 +2746,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
                                          isPrecise, srcLoc);
   }
   // According to DXIL spec, the PrimitiveID SV can only be used by PCIn, HSIn,
-  // DSIn, GSIn, GSOut, and PSIn.
+  // DSIn, GSIn, GSOut, PSIn, and MSPOut.
   // According to Vulkan spec, the PrimitiveId BuiltIn can only be used in
-  // HS/DS/PS In, GS In/Out.
+  // HS/DS/PS In, GS In/Out, MSPOut.
   case hlsl::Semantic::Kind::PrimitiveID: {
     // Translate to PrimitiveId BuiltIn for all valid SigPoints.
     stageVar->setIsSpirvBuiltin();
@@ -2666,9 +2828,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     return spvBuilder.addStageBuiltinVar(type, sc, bi, isPrecise, srcLoc);
   }
   // According to DXIL spec, the RenderTargetArrayIndex SV can only be used by
-  // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn.
-  // According to Vulkan spec, the Layer BuiltIn can only be used in GSOut and
-  // PSIn.
+  // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn, MSPOut.
+  // According to Vulkan spec, the Layer BuiltIn can only be used in GSOut,
+  // PSIn, and MSPOut.
   case hlsl::Semantic::Kind::RenderTargetArrayIndex: {
     switch (sigPointKind) {
     case hlsl::SigPoint::Kind::VSIn:
@@ -2686,6 +2848,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
                                            srcLoc);
     case hlsl::SigPoint::Kind::GSOut:
     case hlsl::SigPoint::Kind::PSIn:
+    case hlsl::SigPoint::Kind::MSPOut:
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer, isPrecise,
                                            srcLoc);
@@ -2694,9 +2857,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     }
   }
   // According to DXIL spec, the ViewportArrayIndex SV can only be used by
-  // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn.
+  // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn, MSPOut.
   // According to Vulkan spec, the ViewportIndex BuiltIn can only be used in
-  // GSOut and PSIn.
+  // GSOut, PSIn, and MSPOut.
   case hlsl::Semantic::Kind::ViewPortArrayIndex: {
     switch (sigPointKind) {
     case hlsl::SigPoint::Kind::VSIn:
@@ -2714,6 +2877,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
                                            isPrecise, srcLoc);
     case hlsl::SigPoint::Kind::GSOut:
     case hlsl::SigPoint::Kind::PSIn:
+    case hlsl::SigPoint::Kind::MSPOut:
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex,
                                            isPrecise, srcLoc);
@@ -2852,6 +3016,7 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl,
       case hlsl::SigPoint::Kind::GSVIn:
       case hlsl::SigPoint::Kind::GSOut:
       case hlsl::SigPoint::Kind::PSIn:
+      case hlsl::SigPoint::Kind::MSOut:
         break;
       default:
         emitError("PointSize builtin cannot be used as %0", loc)
@@ -2867,9 +3032,20 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl,
         success = false;
       }
 
-      if (sigPoint->GetKind() != hlsl::SigPoint::Kind::VSIn) {
-        emitError("%0 builtin can only be used in vertex shader input", loc)
-            << builtin;
+      switch (sigPoint->GetKind()) {
+      case hlsl::SigPoint::Kind::VSIn:
+        break;
+      case hlsl::SigPoint::Kind::MSIn:
+      case hlsl::SigPoint::Kind::ASIn:
+        if (builtin != "DrawIndex") {
+          emitError("%0 builtin cannot be used as %1", loc)
+              << builtin << sigPoint->GetName();
+          success = false;
+        }
+        break;
+      default:
+        emitError("%0 builtin cannot be used as %1", loc)
+            << builtin << sigPoint->GetName();
         success = false;
       }
     } else if (builtin == "DeviceIndex") {
@@ -2884,6 +3060,20 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl,
             << builtin;
         success = false;
       }
+    } else if (builtin == "ViewportMaskNV") {
+      if (sigPoint->GetKind() != hlsl::SigPoint::Kind::MSPOut) {
+        emitError("%0 builtin can only be used as 'primitives' output in MS",
+                  loc)
+            << builtin;
+        success = false;
+      }
+      if (!declType->isArrayType() ||
+          !declType->getArrayElementTypeNoTypeQual()->isSpecificBuiltinType(
+              BuiltinType::Kind::Int)) {
+        emitError("%0 builtin must be of type array of integers", loc)
+            << builtin;
+        success = false;
+      }
     }
   }
 
@@ -2911,6 +3101,8 @@ DeclResultIdMapper::getStorageClassForSigPoint(const hlsl::SigPoint *sigPoint) {
     case hlsl::DXIL::SigPointKind::HSIn:
     case hlsl::DXIL::SigPointKind::GSIn:
     case hlsl::DXIL::SigPointKind::CSIn:
+    case hlsl::DXIL::SigPointKind::MSIn:
+    case hlsl::DXIL::SigPointKind::ASIn:
       sc = spv::StorageClass::Input;
       break;
     default:
@@ -2918,12 +3110,14 @@ DeclResultIdMapper::getStorageClassForSigPoint(const hlsl::SigPoint *sigPoint) {
     }
     break;
   }
-  case hlsl::DXIL::SignatureKind::PatchConstant: {
+  case hlsl::DXIL::SignatureKind::PatchConstOrPrim: {
     // There are some special cases in HLSL (See docs/dxil.rst):
-    // SignatureKind is "PatchConstant" for PCOut and DSIn.
+    // SignatureKind is "PatchConstOrPrim" for PCOut, MSPOut and DSIn.
     switch (sigPointKind) {
     case hlsl::DXIL::SigPointKind::PCOut:
+    case hlsl::DXIL::SigPointKind::MSPOut:
       // Patch Constant Output (Output of Hull which is passed to Domain).
+      // Mesh Shader per-primitive output attributes.
       sc = spv::StorageClass::Output;
       break;
     case hlsl::DXIL::SigPointKind::DSIn:

+ 20 - 5
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -297,6 +297,13 @@ public:
   SpirvVariable *createRayTracingNVStageVar(spv::StorageClass sc,
                                             const VarDecl *decl);
 
+  bool createPayloadStageVars(const hlsl::SigPoint *sigPoint,
+                              spv::StorageClass sc, const NamedDecl *decl,
+                              bool asInput, QualType type,
+                              const llvm::StringRef namePrefix,
+                              SpirvInstruction **value,
+                              uint32_t payloadMemOffset = 0);
+
   /// \brief Creates a function-scope paramter in the current function and
   /// returns its instruction.
   SpirvFunctionParameter *createFnParam(const ParmVarDecl *param);
@@ -479,6 +486,16 @@ public:
   bool decorateResourceBindings();
 
   bool requiresLegalization() const { return needsLegalization; }
+ 
+  /// \brief Returns the given decl's HLSL semantic information.
+  static SemanticInfo getStageVarSemantic(const NamedDecl *decl);
+
+  /// \brief Returns SPIR-V instruction for given stage var decl.
+  SpirvInstruction *getStageVarInstruction(const DeclaratorDecl *decl) {
+    auto *value = stageVarInstructions.lookup(decl);
+    assert(value);
+    return value;
+  }
 
 private:
   /// \brief Wrapper method to create a fatal error message and report it
@@ -557,9 +574,6 @@ private:
       const DeclContext *decl, int arraySize, ContextUsageKind usageKind,
       llvm::StringRef typeName, llvm::StringRef varName);
 
-  /// Returns the given decl's HLSL semantic information.
-  static SemanticInfo getStageVarSemantic(const NamedDecl *decl);
-
   /// Creates all the stage variables mapped from semantics on the given decl.
   /// Returns true on sucess.
   ///
@@ -774,8 +788,9 @@ DeclResultIdMapper::DeclResultIdMapper(ASTContext &context,
       glPerVertex(context, spirvContext, spirvBuilder) {}
 
 bool DeclResultIdMapper::decorateStageIOLocations() {
-  if (spvContext.isRay()) {
-    // No location assignment for any raytracing stage variables
+  if (spvContext.isRay() || spvContext.isAS()) {
+    // No location assignment for any raytracing stage variables or
+    // amplification shader variables
     return true;
   }
   // Try both input and output even if input location assignment failed

+ 3 - 0
tools/clang/lib/SPIRV/FeatureManager.cpp

@@ -119,6 +119,7 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
             Extension::GOOGLE_user_type)
       .Case("SPV_KHR_post_depth_coverage", Extension::KHR_post_depth_coverage)
       .Case("SPV_NV_ray_tracing", Extension::NV_ray_tracing)
+      .Case("SPV_NV_mesh_shader", Extension::NV_mesh_shader)
       .Default(Extension::Unknown);
 }
 
@@ -156,6 +157,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) {
     return "SPV_GOOGLE_user_type";
   case Extension::NV_ray_tracing:
     return "SPV_NV_ray_tracing";
+  case Extension::NV_mesh_shader:
+    return "SPV_NV_mesh_shader";
   default:
     break;
   }

+ 70 - 41
tools/clang/lib/SPIRV/GlPerVertex.cpp

@@ -117,6 +117,12 @@ bool GlPerVertex::recordGlPerVertexDeclFacts(const DeclaratorDecl *decl,
   if (type->isVoidType())
     return true;
 
+  // Indices or payload mesh shader param objects don't contain any
+  // builtin variables or semantic strings, so return early.
+  if (decl->hasAttr<HLSLIndicesAttr>() || decl->hasAttr<HLSLPayloadAttr>()) {
+    return true;
+  }
+
   return doGlPerVertexFacts(decl, type, asInput);
 }
 
@@ -146,18 +152,16 @@ bool GlPerVertex::doGlPerVertexFacts(const DeclaratorDecl *decl,
       return doGlPerVertexFacts(
           decl, hlsl::GetHLSLInputPatchElementType(baseType), asInput);
     }
-    if (hlsl::IsHLSLOutputPatchType(baseType)) {
-      return doGlPerVertexFacts(
-          decl, hlsl::GetHLSLOutputPatchElementType(baseType), asInput);
-    }
-
-    if (hlsl::IsHLSLStreamOutputType(baseType)) {
+    if (hlsl::IsHLSLOutputPatchType(baseType) ||
+        hlsl::IsHLSLStreamOutputType(baseType)) {
       return doGlPerVertexFacts(
           decl, hlsl::GetHLSLOutputPatchElementType(baseType), asInput);
     }
-    if (hasGSPrimitiveTypeQualifier(decl)) {
-      // GS inputs have an additional arrayness that we should remove to check
-      // the underlying type instead.
+    if (hasGSPrimitiveTypeQualifier(decl) ||
+        decl->hasAttr<HLSLVerticesAttr>() ||
+        decl->hasAttr<HLSLPrimitivesAttr>()) {
+      // GS inputs and MS output attributes have an additional arrayness that we
+      // should remove to check the underlying type instead.
       baseType = astContext.getAsConstantArrayType(baseType)->getElementType();
       return doGlPerVertexFacts(decl, baseType, asInput);
     }
@@ -169,7 +173,7 @@ bool GlPerVertex::doGlPerVertexFacts(const DeclaratorDecl *decl,
     return false;
   }
 
-  // Semantic string is attched to this decl directly
+  // Semantic string is attached to this decl directly
 
   // Select the corresponding data member to update
   SemanticIndexToTypeMap *typeMap = nullptr;
@@ -235,7 +239,8 @@ bool GlPerVertex::doGlPerVertexFacts(const DeclaratorDecl *decl,
   }
 
   if (baseType->isConstantArrayType()) {
-    if (spvContext.isHS() || spvContext.isDS() || spvContext.isGS()) {
+    if (spvContext.isHS() || spvContext.isDS() || spvContext.isGS() ||
+        spvContext.isMS()) {
       // Ignore the outermost arrayness and check the inner type to be
       // (vector of) floats
 
@@ -355,11 +360,14 @@ bool GlPerVertex::tryToAccess(hlsl::SigPoint::Kind sigPointKind,
                               uint32_t semanticIndex,
                               llvm::Optional<SpirvInstruction *> invocationId,
                               SpirvInstruction **value, bool noWriteBack,
+                              SpirvInstruction *vecComponent,
                               SourceLocation loc) {
   assert(value);
   // invocationId should only be used for HSPCOut.
-  assert(invocationId.hasValue() ? sigPointKind == hlsl::SigPoint::Kind::HSCPOut
-                                 : true);
+  assert(invocationId.hasValue()
+             ? (sigPointKind == hlsl::SigPoint::Kind::HSCPOut ||
+                sigPointKind == hlsl::SigPoint::Kind::MSOut)
+             : true);
 
   switch (semanticKind) {
   case hlsl::Semantic::Kind::ClipDistance:
@@ -381,10 +389,12 @@ bool GlPerVertex::tryToAccess(hlsl::SigPoint::Kind sigPointKind,
   case hlsl::SigPoint::Kind::VSOut:
   case hlsl::SigPoint::Kind::HSCPOut:
   case hlsl::SigPoint::Kind::DSOut:
+  case hlsl::SigPoint::Kind::MSOut:
     if (noWriteBack)
       return true;
 
-    return writeField(semanticKind, semanticIndex, invocationId, value, loc);
+    return writeField(semanticKind, semanticIndex, invocationId, value,
+                      vecComponent, loc);
   default:
     // Only interfaces that involve gl_PerVertex are needed.
     break;
@@ -526,7 +536,7 @@ bool GlPerVertex::readField(hlsl::Semantic::Kind semanticKind,
 
 void GlPerVertex::writeClipCullArrayFromType(
     llvm::Optional<SpirvInstruction *> invocationId, bool isClip,
-    uint32_t offset, QualType fromType, SpirvInstruction *fromValue,
+    SpirvInstruction *offset, QualType fromType, SpirvInstruction *fromValue,
     SourceLocation loc) const {
   auto *clipCullVar = isClip ? outClipVar : outCullVar;
 
@@ -542,10 +552,8 @@ void GlPerVertex::writeClipCullArrayFromType(
     uint32_t count = {};
 
     if (isScalarType(fromType)) {
-      auto *constant = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
-                                                 llvm::APInt(32, offset));
       auto *ptr =
-          spvBuilder.createAccessChain(f32Type, clipCullVar, {constant}, loc);
+          spvBuilder.createAccessChain(f32Type, clipCullVar, {offset}, loc);
       spvBuilder.createStore(ptr, fromValue, loc);
       return;
     }
@@ -556,9 +564,13 @@ void GlPerVertex::writeClipCullArrayFromType(
       for (uint32_t i = 0; i < count; ++i) {
         // Write elements sequentially into the float array
         auto *constant = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
-                                                   llvm::APInt(32, offset + i));
-        auto *ptr =
-            spvBuilder.createAccessChain(f32Type, clipCullVar, {constant}, loc);
+                                                   llvm::APInt(32, i));
+        auto *ptr = spvBuilder.createAccessChain(
+            f32Type, clipCullVar,
+            {spvBuilder.createBinaryOp(spv::Op::OpIAdd,
+                                       astContext.UnsignedIntTy, offset,
+                                       constant, loc)},
+            loc);
         auto *subValue =
             spvBuilder.createCompositeExtract(f32Type, fromValue, {i}, loc);
         spvBuilder.createStore(ptr, subValue, loc);
@@ -571,8 +583,8 @@ void GlPerVertex::writeClipCullArrayFromType(
     return;
   }
 
-  // Writing to an array only happens in HSCPOut.
-  assert(spvContext.isHS());
+  // Writing to an array only happens in HSCPOut or MSOut.
+  assert(spvContext.isHS() || spvContext.isMS());
   // And we are only writing to the array element with InvocationId as index.
   assert(invocationId.hasValue());
 
@@ -588,11 +600,8 @@ void GlPerVertex::writeClipCullArrayFromType(
   uint32_t count = {};
 
   if (isScalarType(fromType)) {
-    auto *ptr = spvBuilder.createAccessChain(
-        f32Type, clipCullVar,
-        {arrayIndex, spvBuilder.getConstantInt(astContext.UnsignedIntTy,
-                                               llvm::APInt(32, offset))},
-        loc);
+    auto *ptr = spvBuilder.createAccessChain(f32Type, clipCullVar,
+                                             {arrayIndex, offset}, loc);
     spvBuilder.createStore(ptr, fromValue, loc);
     return;
   }
@@ -605,8 +614,11 @@ void GlPerVertex::writeClipCullArrayFromType(
           // Block array index
           {arrayIndex,
            // Write elements sequentially into the float array
-           spvBuilder.getConstantInt(astContext.UnsignedIntTy,
-                                     llvm::APInt(32, offset + i))},
+           spvBuilder.createBinaryOp(
+               spv::Op::OpIAdd, astContext.UnsignedIntTy, offset,
+               spvBuilder.getConstantInt(astContext.UnsignedIntTy,
+                                         llvm::APInt(32, i)),
+               loc)},
           loc);
 
       auto *subValue =
@@ -623,7 +635,9 @@ void GlPerVertex::writeClipCullArrayFromType(
 bool GlPerVertex::writeField(hlsl::Semantic::Kind semanticKind,
                              uint32_t semanticIndex,
                              llvm::Optional<SpirvInstruction *> invocationId,
-                             SpirvInstruction **value, SourceLocation loc) {
+                             SpirvInstruction **value,
+                             SpirvInstruction *vecComponent,
+                             SourceLocation loc) {
   // Similar to the writing logic in DeclResultIdMapper::createStageVars():
   //
   // Unlike reading, which may require us to read stand-alone builtins and
@@ -636,9 +650,13 @@ bool GlPerVertex::writeField(hlsl::Semantic::Kind semanticKind,
   //       but we only write to the struct at the InvocationID index
   // * DS: output is a single struct, without extra arrayness
   // * GS: output is controlled by OpEmitVertex, one vertex per time
+  // * MS: output is an array of structs, with extra arrayness
   //
   // The interesting shader stage is HS. We need the InvocationID to write
   // out the value to the correct array element.
+  SpirvInstruction *offset = nullptr;
+  QualType type;
+  bool isClip = false;
   switch (semanticKind) {
   case hlsl::Semantic::Kind::ClipDistance: {
     const auto offsetIter = outClipOffset.find(semanticIndex);
@@ -646,10 +664,11 @@ bool GlPerVertex::writeField(hlsl::Semantic::Kind semanticKind,
     // We should have recorded all these semantics before.
     assert(offsetIter != outClipOffset.end());
     assert(typeIter != outClipType.end());
-    writeClipCullArrayFromType(invocationId, /*isClip=*/true,
-                               offsetIter->second, typeIter->second, *value,
-                               loc);
-    return true;
+    offset = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
+                                       llvm::APInt(32, offsetIter->second));
+    type = typeIter->second;
+    isClip = true;
+    break;
   }
   case hlsl::Semantic::Kind::CullDistance: {
     const auto offsetIter = outCullOffset.find(semanticIndex);
@@ -657,16 +676,26 @@ bool GlPerVertex::writeField(hlsl::Semantic::Kind semanticKind,
     // We should have recorded all these semantics before.
     assert(offsetIter != outCullOffset.end());
     assert(typeIter != outCullType.end());
-    writeClipCullArrayFromType(invocationId, /*isClip=*/false,
-                               offsetIter->second, typeIter->second, *value,
-                               loc);
-    return true;
+    offset = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
+                                       llvm::APInt(32, offsetIter->second));
+    type = typeIter->second;
+    break;
   }
   default:
     // Only Cull or Clip apply.
-    break;
+    return false;
   }
-  return false;
+  if (vecComponent) {
+    QualType elemType;
+    if (!isVectorType(type, &elemType)) {
+      assert(false && "expected vector type");
+    }
+    type = elemType;
+    offset = spvBuilder.createBinaryOp(
+        spv::Op::OpIAdd, astContext.UnsignedIntTy, vecComponent, offset, loc);
+  }
+  writeClipCullArrayFromType(invocationId, isClip, offset, type, *value, loc);
+  return true;
 }
 
 } // end namespace spirv

+ 5 - 4
tools/clang/lib/SPIRV/GlPerVertex.h

@@ -84,7 +84,7 @@ public:
                    uint32_t semanticIndex,
                    llvm::Optional<SpirvInstruction *> invocation,
                    SpirvInstruction **value, bool noWriteBack,
-                   SourceLocation loc);
+                   SpirvInstruction *vecComponent, SourceLocation loc);
 
 private:
   template <unsigned N>
@@ -114,13 +114,14 @@ private:
   /// generated to make sure type correctness.
   void
   writeClipCullArrayFromType(llvm::Optional<SpirvInstruction *> invocationId,
-                             bool isClip, uint32_t offset, QualType fromType,
-                             SpirvInstruction *fromValue,
+                             bool isClip, SpirvInstruction *offset,
+                             QualType fromType, SpirvInstruction *fromValue,
                              SourceLocation loc) const;
   /// Creates SPIR-V instructions to write a field in gl_PerVertex.
   bool writeField(hlsl::Semantic::Kind semanticKind, uint32_t semanticIndex,
                   llvm::Optional<SpirvInstruction *> invocationId,
-                  SpirvInstruction **value, SourceLocation loc);
+                  SpirvInstruction **value, SpirvInstruction *vecComponent,
+                  SourceLocation loc);
 
   /// Internal implementation for recordClipCullDistanceDecl().
   bool doGlPerVertexFacts(const DeclaratorDecl *decl, QualType type,

+ 17 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -965,6 +965,23 @@ void SpirvBuilder::decorateNoContraction(SpirvInstruction *target,
   module->addDecoration(decor);
 }
 
+void SpirvBuilder::decoratePerPrimitiveNV(SpirvInstruction *target,
+                                          SourceLocation srcLoc) {
+  auto *decor = new (context)
+      SpirvDecoration(srcLoc, target, spv::Decoration::PerPrimitiveNV);
+  module->addDecoration(decor);
+}
+
+void SpirvBuilder::decoratePerTaskNV(SpirvInstruction *target, uint32_t offset,
+                                     SourceLocation srcLoc) {
+  auto *decor =
+      new (context) SpirvDecoration(srcLoc, target, spv::Decoration::PerTaskNV);
+  module->addDecoration(decor);
+  decor = new (context)
+      SpirvDecoration(srcLoc, target, spv::Decoration::Offset, {offset});
+  module->addDecoration(decor);
+}
+
 SpirvConstant *SpirvBuilder::getConstantInt(QualType type, llvm::APInt value,
                                             bool specConst) {
   // We do not reuse existing constant integers. Just create a new one.

+ 532 - 43
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -512,18 +512,22 @@ SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
     spirvOptions.cBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
     spirvOptions.tBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
     spirvOptions.sBufferLayoutRule = SpirvLayoutRule::FxcSBuffer;
+    spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::FxcSBuffer;
   } else if (spirvOptions.useGlLayout) {
     spirvOptions.cBufferLayoutRule = SpirvLayoutRule::GLSLStd140;
     spirvOptions.tBufferLayoutRule = SpirvLayoutRule::GLSLStd430;
     spirvOptions.sBufferLayoutRule = SpirvLayoutRule::GLSLStd430;
+    spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::GLSLStd430;
   } else if (spirvOptions.useScalarLayout) {
     spirvOptions.cBufferLayoutRule = SpirvLayoutRule::Scalar;
     spirvOptions.tBufferLayoutRule = SpirvLayoutRule::Scalar;
     spirvOptions.sBufferLayoutRule = SpirvLayoutRule::Scalar;
+    spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::Scalar;
   } else {
     spirvOptions.cBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd140;
     spirvOptions.tBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
     spirvOptions.sBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
+    spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
   }
 
   // Set shader module version, source file name, and source file content (if
@@ -5005,6 +5009,11 @@ SpirvEmitter::processAssignment(const Expr *lhs, SpirvInstruction *rhs,
   if (SpirvInstruction *result = tryToAssignToRWBufferRWTexture(lhs, rhs))
     return result;
 
+  // Assigning to an out attribute or indices object in mesh shader should be
+  // handled differently.
+  if (SpirvInstruction *result = tryToAssignToMSOutAttrsOrIndices(lhs, rhs))
+    return result;
+
   // Normal assignment procedure
 
   if (!lhsPtr)
@@ -5849,6 +5858,11 @@ SpirvEmitter::tryToAssignToVectorElements(const Expr *lhs,
       (void)result;
       return rhs; // TODO: incorrect for compound assignments
     } else {
+      // Assigning to one component of mesh out attribute/indices vector object.
+      SpirvInstruction *vecComponent = spvBuilder.getConstantInt(
+          astContext.UnsignedIntTy, llvm::APInt(32, accessor.Swz0));
+      if (tryToAssignToMSOutAttrsOrIndices(base, rhs, vecComponent))
+        return rhs;
       // Assigning to one normal vector component. Nothing special, just fall
       // back to the normal CodeGen path.
       return nullptr;
@@ -5870,6 +5884,26 @@ SpirvEmitter::tryToAssignToVectorElements(const Expr *lhs,
     return processAssignment(base, rhs, false);
   }
 
+  if (tryToAssignToMSOutAttrsOrIndices(base, rhs, /*vecComponent=*/nullptr,
+                                       /*noWriteBack=*/true)) {
+    // Assigning to 'n' components of mesh out attribute/indices vector object.
+    const QualType elemType =
+        hlsl::GetHLSLVecElementType(rhs->getAstResultType());
+    uint32_t i = 0;
+    for (; i < accessor.Count; ++i) {
+      auto *rhsElem = spvBuilder.createCompositeExtract(elemType, rhs, {i},
+                                                        lhs->getLocStart());
+      uint32_t position;
+      accessor.GetPosition(i, &position);
+      SpirvInstruction *vecComponent = spvBuilder.getConstantInt(
+          astContext.UnsignedIntTy, llvm::APInt(32, position));
+      if (!tryToAssignToMSOutAttrsOrIndices(base, rhsElem, vecComponent))
+        break;
+    }
+    assert(i == accessor.Count);
+    return rhs;
+  }
+
   llvm::SmallVector<uint32_t, 4> selectors;
   selectors.resize(baseSize);
   // Assume we are selecting all original elements first.
@@ -5985,6 +6019,194 @@ SpirvEmitter::tryToAssignToMatrixElements(const Expr *lhs,
   return rhs;
 }
 
+SpirvInstruction *SpirvEmitter::tryToAssignToMSOutAttrsOrIndices(
+    const Expr *lhs, SpirvInstruction *rhs, SpirvInstruction *vecComponent,
+    bool noWriteBack) {
+  // Early exit for non-mesh shaders.
+  if (!spvContext.isMS())
+    return nullptr;
+
+  llvm::SmallVector<SpirvInstruction *, 4> indices;
+  bool isMSOutAttribute = false;
+  bool isMSOutAttributeBlock = false;
+  bool isMSOutIndices = false;
+
+  const Expr *base = collectArrayStructIndices(lhs, /*rawIndex*/ false,
+                                               /*rawIndices*/ nullptr, &indices,
+                                               &isMSOutAttribute);
+  // Expecting at least one array index - early exit.
+  if (!base || indices.empty())
+    return nullptr;
+
+  const DeclaratorDecl *varDecl = nullptr;
+  if (isMSOutAttribute) {
+    const MemberExpr *memberExpr = dyn_cast<MemberExpr>(base);
+    assert(memberExpr);
+    varDecl = cast<DeclaratorDecl>(memberExpr->getMemberDecl());
+  } else {
+    if (const auto *arg = dyn_cast<DeclRefExpr>(base)) {
+      if (varDecl = dyn_cast<DeclaratorDecl>(arg->getDecl())) {
+        if (varDecl->hasAttr<HLSLIndicesAttr>()) {
+          isMSOutIndices = true;
+        } else if (varDecl->hasAttr<HLSLVerticesAttr>() ||
+                   varDecl->hasAttr<HLSLPrimitivesAttr>()) {
+          isMSOutAttributeBlock = true;
+        }
+      }
+    }
+  }
+
+  // Return if no out attribute or indices object found.
+  if (!(isMSOutAttribute || isMSOutAttributeBlock || isMSOutIndices)) {
+    return nullptr;
+  }
+
+  // For noWriteBack, return without generating write instructions.
+  if (noWriteBack) {
+    return rhs;
+  }
+
+  // Add vecComponent to indices.
+  if (vecComponent) {
+    indices.push_back(vecComponent);
+  }
+
+  if (isMSOutAttribute) {
+    assignToMSOutAttribute(varDecl, rhs, indices);
+  } else if (isMSOutIndices) {
+    assignToMSOutIndices(varDecl, rhs, indices);
+  } else {
+    assert(isMSOutAttributeBlock);
+    QualType type = varDecl->getType();
+    assert(isa<ConstantArrayType>(type));
+    type = astContext.getAsConstantArrayType(type)->getElementType();
+    assert(type->isStructureType());
+
+    // Extract subvalue and assign to its corresponding member attribute.
+    const auto *structDecl = type->getAs<RecordType>()->getDecl();
+    for (const auto *field : structDecl->fields()) {
+      const auto fieldType = field->getType();
+      SpirvInstruction *subValue = spvBuilder.createCompositeExtract(
+          fieldType, rhs, {getNumBaseClasses(type) + field->getFieldIndex()},
+          lhs->getLocStart());
+      assignToMSOutAttribute(field, subValue, indices);
+    }
+  }
+
+  // TODO: OK, this return value is incorrect for compound assignments, for
+  // which cases we should return lvalues. Should at least emit errors if
+  // this return value is used (can be checked via ASTContext.getParents).
+  return rhs;
+}
+
+void SpirvEmitter::assignToMSOutAttribute(
+    const DeclaratorDecl *decl, SpirvInstruction *value,
+    const llvm::SmallVector<SpirvInstruction *, 4> &indices) {
+  assert(spvContext.isMS() && !indices.empty());
+
+  // Extract attribute index and vecComponent (if any).
+  SpirvInstruction *attrIndex = indices.front();
+  SpirvInstruction *vecComponent = nullptr;
+  if (indices.size() > 1) {
+    vecComponent = indices.back();
+  }
+
+  auto semanticInfo = declIdMapper.getStageVarSemantic(decl);
+  assert(semanticInfo.isValid());
+  const auto loc = decl->getLocation();
+  // Handle writes to clip/cull distance attributes specially.
+  if (!declIdMapper.glPerVertex.tryToAccess(
+          hlsl::DXIL::SigPointKind::MSOut, semanticInfo.semantic->GetKind(),
+          semanticInfo.index, attrIndex, &value, /*noWriteBack=*/false,
+          vecComponent, loc)) {
+    // All other attribute writes are handled below.
+    auto *varInstr = declIdMapper.getStageVarInstruction(decl);
+    QualType valueType = value->getAstResultType();
+    varInstr = spvBuilder.createAccessChain(valueType, varInstr, indices, loc);
+    spvBuilder.createStore(varInstr, value, loc);
+  }
+}
+
+void SpirvEmitter::assignToMSOutIndices(
+    const DeclaratorDecl *decl, SpirvInstruction *value,
+    const llvm::SmallVector<SpirvInstruction *, 4> &indices) {
+  assert(spvContext.isMS() && !indices.empty());
+
+  // Extract vertex index and vecComponent (if any).
+  SpirvInstruction *vertIndex = indices.front();
+  SpirvInstruction *vecComponent = nullptr;
+  if (indices.size() > 1) {
+    vecComponent = indices.back();
+  }
+  auto *var = declIdMapper.getStageVarInstruction(decl);
+  const auto *varTypeDecl = astContext.getAsConstantArrayType(decl->getType());
+  QualType varType = varTypeDecl->getElementType();
+  uint32_t numVertices = 1;
+  if (!isVectorType(varType, nullptr, &numVertices)) {
+    assert(isScalarType(varType));
+  }
+  QualType valueType = value->getAstResultType();
+  uint32_t numValues = 1;
+  if (!isVectorType(valueType, nullptr, &numValues)) {
+    assert(isScalarType(valueType));
+  }
+
+  const auto loc = decl->getLocation();
+  if (numVertices == 1) {
+    // for "point" output topology.
+    assert(numValues == 1);
+    // create accesschain for PrimitiveIndicesNV[vertIndex].
+    auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
+                                             {vertIndex}, loc);
+    // finally create store for PrimitiveIndicesNV[vertIndex] = value.
+    spvBuilder.createStore(ptr, value, loc);
+  } else {
+    // for "line" or "triangle" output topology.
+    assert(numVertices == 2 || numVertices == 3);
+    // set baseOffset = vertIndex * numVertices.
+    auto *baseOffset = spvBuilder.createBinaryOp(
+        spv::Op::OpIMul, astContext.UnsignedIntTy, vertIndex,
+        spvBuilder.getConstantInt(astContext.UnsignedIntTy,
+                                  llvm::APInt(32, numVertices)),
+        loc);
+    if (vecComponent) {
+      // write an individual vector component of uint2 or uint3.
+      assert(numValues == 1);
+      // set baseOffset = baseOffset + vecComponent.
+      baseOffset =
+          spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
+                                    baseOffset, vecComponent, loc);
+      // create accesschain for PrimitiveIndicesNV[baseOffset].
+      auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
+                                               {baseOffset}, loc);
+      // finally create store for PrimitiveIndicesNV[baseOffset] = value.
+      spvBuilder.createStore(ptr, value, loc);
+    } else {
+      // write all vector components of uint2 or uint3.
+      assert(numValues == numVertices);
+      auto *curOffset = baseOffset;
+      for (uint32_t i = 0; i < numValues; ++i) {
+        if (i != 0) {
+          // set curOffset = baseOffset + i.
+          curOffset = spvBuilder.createBinaryOp(
+              spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset,
+              spvBuilder.getConstantInt(astContext.UnsignedIntTy,
+                                        llvm::APInt(32, i)),
+              loc);
+        }
+        // create accesschain for PrimitiveIndicesNV[curOffset].
+        auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
+                                                 {curOffset}, loc);
+        // finally create store for PrimitiveIndicesNV[curOffset] = value[i].
+        spvBuilder.createStore(ptr,
+                               spvBuilder.createCompositeExtract(
+                                   astContext.UnsignedIntTy, value, {i}, loc),
+                               loc);
+      }
+    }
+  }
+}
+
 SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
     const Expr *matrix, SpirvInstruction *matrixVal,
     llvm::function_ref<SpirvInstruction *(uint32_t, QualType,
@@ -6141,7 +6363,8 @@ SpirvEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
 const Expr *SpirvEmitter::collectArrayStructIndices(
     const Expr *expr, bool rawIndex,
     llvm::SmallVectorImpl<uint32_t> *rawIndices,
-    llvm::SmallVectorImpl<SpirvInstruction *> *indices) {
+    llvm::SmallVectorImpl<SpirvInstruction *> *indices,
+    bool *isMSOutAttribute) {
   assert((rawIndex && rawIndices) || (!rawIndex && indices));
 
   if (const auto *indexing = dyn_cast<MemberExpr>(expr)) {
@@ -6156,7 +6379,20 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
 
     const Expr *base = collectArrayStructIndices(
         indexing->getBase()->IgnoreParenNoopCasts(astContext), rawIndex,
-        rawIndices, indices);
+        rawIndices, indices, isMSOutAttribute);
+
+    if (isMSOutAttribute && base) {
+      if (const auto *arg = dyn_cast<DeclRefExpr>(base)) {
+        if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
+          if (varDecl->hasAttr<HLSLVerticesAttr>() ||
+              varDecl->hasAttr<HLSLPrimitivesAttr>()) {
+            assert(spvContext.isMS());
+            *isMSOutAttribute = true;
+            return expr;
+          }
+        }
+      }
+    }
 
     // Append the index of the current level
     const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
@@ -6183,8 +6419,8 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
     // The base of an ArraySubscriptExpr has a wrapping LValueToRValue implicit
     // cast. We need to ingore it to avoid creating OpLoad.
     const Expr *thisBase = indexing->getBase()->IgnoreParenLValueCasts();
-    const Expr *base =
-        collectArrayStructIndices(thisBase, rawIndex, rawIndices, indices);
+    const Expr *base = collectArrayStructIndices(thisBase, rawIndex, rawIndices,
+                                                 indices, isMSOutAttribute);
     indices->push_back(doExpr(indexing->getIdx()));
     return base;
   }
@@ -6204,8 +6440,8 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
           indexing->getArg(0)->IgnoreParenNoopCasts(astContext);
 
       const auto thisBaseType = thisBase->getType();
-      const Expr *base =
-          collectArrayStructIndices(thisBase, rawIndex, rawIndices, indices);
+      const Expr *base = collectArrayStructIndices(
+          thisBase, rawIndex, rawIndices, indices, isMSOutAttribute);
 
       if (thisBaseType != base->getType() &&
           isAKindOfStructuredOrByteBuffer(thisBaseType)) {
@@ -6857,6 +7093,14 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
   case hlsl::IntrinsicOp::IOP_CallShader: {
     processCallShader(callExpr);
     break;
+  }
+  case hlsl::IntrinsicOp::IOP_DispatchMesh: {
+    processDispatchMesh(callExpr);
+    break;
+  }
+  case hlsl::IntrinsicOp::IOP_SetMeshOutputCounts: {
+    processMeshOutputCounts(callExpr);
+    break;
   }
     INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
     INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
@@ -9188,6 +9432,7 @@ SpirvInstruction *SpirvEmitter::processReportHit(const CallExpr *callExpr) {
                                           astContext.BoolTy, reportHitArgs,
                                           callExpr->getExprLoc());
 }
+
 void SpirvEmitter::processCallShader(const CallExpr *callExpr) {
   SpirvInstruction *callDataLocInst = nullptr;
   SpirvInstruction *callDataStageVar = nullptr;
@@ -9254,11 +9499,12 @@ void SpirvEmitter::processCallShader(const CallExpr *callExpr) {
   spvBuilder.createStore(callDataArgInst, tempLoad, callExpr->getExprLoc());
   return;
 }
+
 void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
-  SpirvInstruction *payloadLocInst = nullptr;
-  SpirvInstruction *payloadStageVar = nullptr;
-  const VarDecl *payloadArg = nullptr;
-  QualType payloadType;
+  SpirvInstruction *rayPayloadLocInst = nullptr;
+  SpirvInstruction *rayPayloadStageVar = nullptr;
+  const VarDecl *rayPayloadArg = nullptr;
+  QualType rayPayloadType;
 
   const auto args = callExpr->getArgs();
 
@@ -9268,7 +9514,7 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
   }
 
   // HLSL Func
-  // template<typename Payload>
+  // template<typename RayPayload>
   // void TraceRay(RaytracingAccelerationStructure rs,
   //              uint rayflags,
   //              uint InstanceInclusionMask
@@ -9276,36 +9522,36 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
   //              uint MultiplierForGeometryContributionToHitGroupIndex,
   //              uint MissShaderIndex,
   //              RayDesc ray,
-  //              inout Payload p)
+  //              inout RayPayload p)
   // where RayDesc = {float3 origin, float tMin, float3 direction, float tMax}
 
   if (const auto *implCastExpr = dyn_cast<CastExpr>(args[7])) {
     if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
       if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
-        payloadType = varDecl->getType();
-        payloadArg = varDecl;
-        const auto payloadPair = payloadMap.find(payloadType);
-        // Check if same type of payload stage variable was already
+        rayPayloadType = varDecl->getType();
+        rayPayloadArg = varDecl;
+        const auto rayPayloadPair = rayPayloadMap.find(rayPayloadType);
+        // Check if same type of rayPayload stage variable was already
         // created, if so re-use
-        if (payloadPair == payloadMap.end()) {
-          int numPayloadVars = payloadMap.size();
-          payloadStageVar = declIdMapper.createRayTracingNVStageVar(
+        if (rayPayloadPair == rayPayloadMap.end()) {
+          int numPayloadVars = rayPayloadMap.size();
+          rayPayloadStageVar = declIdMapper.createRayTracingNVStageVar(
               spv::StorageClass::RayPayloadNV, varDecl);
           // Decorate unique location id for each created stage var
-          spvBuilder.decorateLocation(payloadStageVar, numPayloadVars);
-          payloadLocInst = spvBuilder.getConstantInt(
+          spvBuilder.decorateLocation(rayPayloadStageVar, numPayloadVars);
+          rayPayloadLocInst = spvBuilder.getConstantInt(
               astContext.UnsignedIntTy, llvm::APInt(32, numPayloadVars));
-          payloadMap[payloadType] =
-              std::make_pair(payloadStageVar, payloadLocInst);
+          rayPayloadMap[rayPayloadType] =
+              std::make_pair(rayPayloadStageVar, rayPayloadLocInst);
         } else {
-          payloadStageVar = payloadPair->second.first;
-          payloadLocInst = payloadPair->second.second;
+          rayPayloadStageVar = rayPayloadPair->second.first;
+          rayPayloadLocInst = rayPayloadPair->second.second;
         }
       }
     }
   }
 
-  assert(payloadStageVar && payloadArg);
+  assert(rayPayloadStageVar && rayPayloadArg);
 
   const auto floatType = astContext.FloatTy;
   const auto vecType = astContext.getExtVectorType(astContext.FloatTy, 3);
@@ -9323,11 +9569,12 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
       spvBuilder.createCompositeExtract(floatType, rayDescArg, {3}, loc);
 
   // Copy argument to stage variable
-  const auto payloadArgInst =
-      declIdMapper.getDeclEvalInfo(payloadArg, payloadArg->getLocStart());
-  auto tempLoad = spvBuilder.createLoad(payloadArg->getType(), payloadArgInst,
-                                        payloadArg->getLocStart());
-  spvBuilder.createStore(payloadStageVar, tempLoad, callExpr->getExprLoc());
+  const auto rayPayloadArgInst =
+      declIdMapper.getDeclEvalInfo(rayPayloadArg, rayPayloadArg->getLocStart());
+  auto tempLoad =
+      spvBuilder.createLoad(rayPayloadArg->getType(), rayPayloadArgInst,
+                            rayPayloadArg->getLocStart());
+  spvBuilder.createStore(rayPayloadStageVar, tempLoad, callExpr->getExprLoc());
 
   // SPIR-V Instruction
   // void OpTraceNV ( <id> AccelerationStructureNV acStruct,
@@ -9340,7 +9587,7 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
   //                 <id> float Ray Tmin,
   //                 <id> vec3 Ray Direction,
   //                 <id> float Ray Tmax,
-  //                 <id> uint Payload number)
+  //                 <id> uint RayPayload number)
 
   llvm::SmallVector<SpirvInstruction *, 8> traceArgs;
   for (int ii = 0; ii < 6; ii++) {
@@ -9351,18 +9598,79 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
   traceArgs.push_back(tMin);
   traceArgs.push_back(direction);
   traceArgs.push_back(tMax);
-  traceArgs.push_back(payloadLocInst);
+  traceArgs.push_back(rayPayloadLocInst);
 
   spvBuilder.createRayTracingOpsNV(spv::Op::OpTraceNV, QualType(), traceArgs,
                                    callExpr->getExprLoc());
 
   // Copy arguments back to stage variable
-  tempLoad = spvBuilder.createLoad(payloadArg->getType(), payloadStageVar,
-                                   payloadArg->getLocStart());
-  spvBuilder.createStore(payloadArgInst, tempLoad, callExpr->getExprLoc());
+  tempLoad = spvBuilder.createLoad(rayPayloadArg->getType(), rayPayloadStageVar,
+                                   rayPayloadArg->getLocStart());
+  spvBuilder.createStore(rayPayloadArgInst, tempLoad, callExpr->getExprLoc());
   return;
 }
 
+void SpirvEmitter::processDispatchMesh(const CallExpr *callExpr) {
+  // Lowers the HLSL amplification shader intrinsic
+  //   void DispatchMesh(uint ThreadGroupCountX,
+  //                     uint ThreadGroupCountY,
+  //                     uint ThreadGroupCountZ,
+  //                     groupshared <structType> MeshPayload);
+  // in three steps: a group-synchronizing memory barrier, a store of the
+  // flattened group count to the TaskCountNV builtin, and the creation of the
+  // payload stage variables from the groupshared argument.
+  assert(callExpr->getNumArgs() == 4);
+  const auto callArgs = callExpr->getArgs();
+  const auto callLoc = callExpr->getExprLoc();
+  const auto uintType = astContext.UnsignedIntTy;
+
+  // Step 1: barrier equivalent to GroupMemoryBarrierWithGroupSync().
+  processIntrinsicMemoryBarrier(callExpr,
+                                /*isDevice*/ false,
+                                /*groupSync*/ true,
+                                /*isAllBarrier*/ false);
+
+  // Step 2: TaskCountNV = countX * (countY * countZ).
+  auto *countX = doExpr(callArgs[0]);
+  auto *countY = doExpr(callArgs[1]);
+  auto *countZ = doExpr(callArgs[2]);
+  auto *taskCountVar = declIdMapper.getBuiltinVar(spv::BuiltIn::TaskCountNV,
+                                                  uintType, callLoc);
+  auto *countYZ = spvBuilder.createBinaryOp(spv::Op::OpIMul, uintType, countY,
+                                            countZ, callLoc);
+  auto *countXYZ = spvBuilder.createBinaryOp(spv::Op::OpIMul, uintType, countX,
+                                             countYZ, callLoc);
+  spvBuilder.createStore(taskCountVar, countXYZ, callLoc);
+
+  // Step 3: create the PerTaskNV out attribute block and store the
+  // MeshPayload info into it.
+  const auto *sigPoint =
+      hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::MSOut);
+  spv::StorageClass sc = spv::StorageClass::Output;
+  auto *payloadArg = doExpr(callArgs[3]);
+
+  // The payload must be written as (cast of) a direct reference to a variable
+  // declaration; anything else leaves payloadDecl null.
+  const VarDecl *payloadDecl = nullptr;
+  if (const auto *castExpr = dyn_cast<CastExpr>(callArgs[3]))
+    if (const auto *declRef = dyn_cast<DeclRefExpr>(castExpr->getSubExpr()))
+      payloadDecl = dyn_cast<VarDecl>(declRef->getDecl());
+
+  // Stage variables are only created for a groupshared declaration.
+  const bool isValid =
+      payloadDecl && payloadDecl->hasAttr<HLSLGroupSharedAttr>() &&
+      declIdMapper.createPayloadStageVars(
+          sigPoint, sc, payloadDecl, /*asInput=*/false, payloadDecl->getType(),
+          "out.var", &payloadArg);
+
+  if (!isValid) {
+    emitError("expected groupshared object as argument to DispatchMesh()",
+              callArgs[3]->getExprLoc());
+  }
+}
+
+void SpirvEmitter::processMeshOutputCounts(const CallExpr *callExpr) {
+  // Lowers the HLSL mesh shader intrinsic
+  //   void SetMeshOutputCounts(uint numVertices, uint numPrimitives);
+  // Only the second argument (numPrimitives) is consumed here: its value is
+  // stored to the PrimitiveCountNV builtin. The first argument is not read.
+  assert(callExpr->getNumArgs() == 2);
+  const auto callArgs = callExpr->getArgs();
+  const auto callLoc = callExpr->getExprLoc();
+  auto *primitiveCountVar = declIdMapper.getBuiltinVar(
+      spv::BuiltIn::PrimitiveCountNV, astContext.UnsignedIntTy, callLoc);
+  auto *numPrimitives = doExpr(callArgs[1]);
+  spvBuilder.createStore(primitiveCountVar, numPrimitives, callLoc);
+}
+
 SpirvConstant *SpirvEmitter::getValueZero(QualType type) {
   {
     QualType scalarType = {};
@@ -9653,10 +9961,24 @@ hlsl::ShaderModel::Kind SpirvEmitter::getShaderModelKind(StringRef stageName) {
     smk = hlsl::ShaderModel::Kind::Intersection;
     break;
   case 'a':
-    smk = hlsl::ShaderModel::Kind::AnyHit;
+    switch (stageName[1]) {
+    case 'm':
+      smk = hlsl::ShaderModel::Kind::Amplification;
+      break;
+    case 'n':
+      smk = hlsl::ShaderModel::Kind::AnyHit;
+      break;
+    }
     break;
   case 'm':
-    smk = hlsl::ShaderModel::Kind::Miss;
+    switch (stageName[1]) {
+    case 'e':
+      smk = hlsl::ShaderModel::Kind::Mesh;
+      break;
+    case 'i':
+      smk = hlsl::ShaderModel::Kind::Miss;
+      break;
+    }
     break;
   default:
     smk = hlsl::ShaderModel::Kind::Invalid;
@@ -9695,6 +10017,10 @@ SpirvEmitter::getSpirvShaderStage(hlsl::ShaderModel::Kind smk) {
     return spv::ExecutionModel::MissNV;
   case hlsl::ShaderModel::Kind::Callable:
     return spv::ExecutionModel::CallableNV;
+  case hlsl::ShaderModel::Kind::Mesh:
+    return spv::ExecutionModel::MeshNV;
+  case hlsl::ShaderModel::Kind::Amplification:
+    return spv::ExecutionModel::TaskNV;
   default:
     llvm_unreachable("invalid shader model kind");
     break;
@@ -9958,8 +10284,8 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
     paramTypes.push_back(paramType);
 
     // Order of arguments is fixed
-    // Any-Hit/Closest-Hit : Arg 0 = payload(inout), Arg1 = attribute(in)
-    // Miss : Arg 0 = payload(inout)
+    // Any-Hit/Closest-Hit : Arg 0 = rayPayload(inout), Arg1 = attribute(in)
+    // Miss : Arg 0 = rayPayload(inout)
     // Callable : Arg 0 = callable data(inout)
     // Raygeneration/Intersection : No Args allowed
     if (sKind == hlsl::ShaderModel::Kind::RayGeneration) {
@@ -9970,7 +10296,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
                sKind == hlsl::ShaderModel::Kind::AnyHit) {
       // Generate rayPayloadInNV and hitAttributeNV stage variables
       if (i == 0) {
-        // First argument is always payload
+        // First argument is always rayPayload
         curStageVar = declIdMapper.createRayTracingNVStageVar(
             spv::StorageClass::IncomingRayPayloadNV, param);
       } else {
@@ -9980,7 +10306,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
       }
     } else if (sKind == hlsl::ShaderModel::Kind::Miss) {
       // Generate rayPayloadInNV stage variable
-      // First and only argument is payload
+      // First and only argument is rayPayload
       curStageVar = declIdMapper.createRayTracingNVStageVar(
           spv::StorageClass::IncomingRayPayloadNV, param);
     } else if (sKind == hlsl::ShaderModel::Kind::Callable) {
@@ -10020,6 +10346,166 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
   return true;
 }
 
+bool SpirvEmitter::processMeshOrAmplificationShaderAttributes(
+    const FunctionDecl *decl, uint32_t *outVerticesArraySize) {
+  // Translates the attributes and parameter modifiers of a mesh or
+  // amplification shader entry point into SPIR-V execution modes.
+  //
+  // decl                 - the entry point function declaration.
+  // outVerticesArraySize - receives the size of the 'vertices' array; it is
+  //                        only written for mesh shaders, once a 'vertices'
+  //                        parameter has been found.
+  //
+  // Returns false after emitting an error when the entry point signature is
+  // not acceptable; returns true otherwise.
+  if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
+    uint32_t x, y, z;
+    x = static_cast<uint32_t>(numThreadsAttr->getX());
+    y = static_cast<uint32_t>(numThreadsAttr->getY());
+    z = static_cast<uint32_t>(numThreadsAttr->getZ());
+    spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
+                                {x, y, z}, decl->getLocation());
+  }
+
+  // Early return for amplification shaders as they only take the 'numthreads'
+  // attribute.
+  if (spvContext.isAS())
+    return true;
+
+  spv::ExecutionMode outputPrimitive = spv::ExecutionMode::Max;
+  if (auto *outputTopology = decl->getAttr<HLSLOutputTopologyAttr>()) {
+    const auto topology = outputTopology->getTopology().lower();
+    // The explicit Default is required: without it, converting a StringSwitch
+    // that matched no case asserts, so an unrecognized topology would never
+    // reach the diagnostic below.
+    outputPrimitive =
+        llvm::StringSwitch<spv::ExecutionMode>(topology)
+            .Case("point", spv::ExecutionMode::OutputPoints)
+            .Case("line", spv::ExecutionMode::OutputLinesNV)
+            .Case("triangle", spv::ExecutionMode::OutputTrianglesNV)
+            .Default(spv::ExecutionMode::Max);
+    if (outputPrimitive != spv::ExecutionMode::Max) {
+      spvBuilder.addExecutionMode(entryFunction, outputPrimitive, {},
+                                  decl->getLocation());
+    } else {
+      emitError("unknown output topology in mesh shader",
+                outputTopology->getLocation());
+      return false;
+    }
+  }
+
+  // Array sizes of the 'vertices', 'indices' and 'primitives' parameters;
+  // zero means "not seen yet".
+  uint32_t numVertices = 0;
+  uint32_t numIndices = 0;
+  uint32_t numPrimitives = 0;
+  bool payloadDeclSeen = false;
+
+  for (uint32_t i = 0; i < decl->getNumParams(); i++) {
+    const auto param = decl->getParamDecl(i);
+    const auto paramType = param->getType();
+    const auto paramLoc = param->getLocation();
+    if (param->hasAttr<HLSLVerticesAttr>() ||
+        param->hasAttr<HLSLIndicesAttr>() ||
+        param->hasAttr<HLSLPrimitivesAttr>()) {
+      uint32_t arraySize = 0;
+      if (const auto *arrayType =
+              astContext.getAsConstantArrayType(paramType)) {
+        const auto eleType =
+            arrayType->getElementType()->getCanonicalTypeUnqualified();
+        if (param->hasAttr<HLSLIndicesAttr>()) {
+          // The element type of the indices array depends on the output
+          // topology: uint for points, uint2 for lines, uint3 for triangles.
+          switch (outputPrimitive) {
+          case spv::ExecutionMode::OutputPoints:
+            if (eleType != astContext.UnsignedIntTy) {
+              emitError("expected 1D array of uint type", paramLoc);
+              return false;
+            }
+            break;
+          case spv::ExecutionMode::OutputLinesNV: {
+            QualType baseType;
+            uint32_t length;
+            if (!isVectorType(eleType, &baseType, &length) ||
+                baseType != astContext.UnsignedIntTy || length != 2) {
+              emitError("expected 1D array of uint2 type", paramLoc);
+              return false;
+            }
+            break;
+          }
+          case spv::ExecutionMode::OutputTrianglesNV: {
+            QualType baseType;
+            uint32_t length;
+            if (!isVectorType(eleType, &baseType, &length) ||
+                baseType != astContext.UnsignedIntTy || length != 3) {
+              emitError("expected 1D array of uint3 type", paramLoc);
+              return false;
+            }
+            break;
+          }
+          default:
+            assert(false && "unexpected spirv execution mode");
+          }
+        } else if (!eleType->isStructureType()) {
+          // vertices/primitives objects
+          emitError("expected 1D array of struct type", paramLoc);
+          return false;
+        }
+        arraySize = static_cast<uint32_t>(arrayType->getSize().getZExtValue());
+      } else {
+        emitError("expected 1D array of indices/vertices/primitives object",
+                  paramLoc);
+        return false;
+      }
+      // Record the array size, rejecting a second parameter that carries the
+      // same modifier.
+      if (param->hasAttr<HLSLVerticesAttr>()) {
+        if (numVertices != 0) {
+          emitError("only one object with 'vertices' modifier is allowed",
+                    paramLoc);
+          return false;
+        }
+        numVertices = arraySize;
+      } else if (param->hasAttr<HLSLIndicesAttr>()) {
+        if (numIndices != 0) {
+          emitError("only one object with 'indices' modifier is allowed",
+                    paramLoc);
+          return false;
+        }
+        numIndices = arraySize;
+      } else if (param->hasAttr<HLSLPrimitivesAttr>()) {
+        if (numPrimitives != 0) {
+          emitError("only one object with 'primitives' modifier is allowed",
+                    paramLoc);
+          return false;
+        }
+        numPrimitives = arraySize;
+      }
+    } else if (param->hasAttr<HLSLPayloadAttr>()) {
+      if (payloadDeclSeen) {
+        emitError("only one object with 'payload' modifier is allowed",
+                  paramLoc);
+        return false;
+      }
+      payloadDeclSeen = true;
+      if (!paramType->isStructureType()) {
+        emitError("expected payload of struct type", paramLoc);
+        return false;
+      }
+    }
+  }
+
+  // Vertex attribute array is a mandatory param to mesh entry function.
+  if (numVertices != 0) {
+    *outVerticesArraySize = numVertices;
+    spvBuilder.addExecutionMode(
+        entryFunction, spv::ExecutionMode::OutputVertices,
+        {static_cast<uint32_t>(numVertices)}, decl->getLocation());
+  } else {
+    emitError("expected vertices object declaration", decl->getLocation());
+    return false;
+  }
+
+  // Vertex indices array is a mandatory param to mesh entry function.
+  if (numIndices != 0) {
+    spvBuilder.addExecutionMode(
+        entryFunction, spv::ExecutionMode::OutputPrimitivesNV,
+        {static_cast<uint32_t>(numIndices)}, decl->getLocation());
+    // Primitive attribute array is an optional param to mesh entry function,
+    // but the array size should match the indices array.
+    if (numPrimitives != 0 && numPrimitives != numIndices) {
+      emitError("array size of primitives object should match 'indices' object",
+                decl->getLocation());
+      return false;
+    }
+  } else {
+    emitError("expected indices object declaration", decl->getLocation());
+    return false;
+  }
+
+  return true;
+}
+
 bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
                                             SpirvFunction *entryFuncInstr) {
   // HS specific attributes
@@ -10089,6 +10575,9 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
     if (!processGeometryShaderAttributes(decl, &inputArraySize))
       return false;
     // The per-vertex output of GS is not an array.
+  } else if (spvContext.isMS() || spvContext.isAS()) {
+    if (!processMeshOrAmplificationShaderAttributes(decl, &outputArraySize))
+      return false;
   }
 
   // Go through all parameters and record the declaration of SV_ClipDistance
@@ -10117,7 +10606,7 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   // offset of SV_ClipDistance/SV_CullDistance variables within the array.
   declIdMapper.glPerVertex.calculateClipCullDistanceArraySize();
 
-  if (!spvContext.isCS()) {
+  if (!spvContext.isCS() && !spvContext.isAS()) {
     // Generate stand-alone builtins of Position, ClipDistance, and
     // CullDistance, which belongs to gl_PerVertex.
     declIdMapper.glPerVertex.generateVars(inputArraySize, outputArraySize);

+ 33 - 2
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -261,6 +261,24 @@ private:
   SpirvInstruction *tryToAssignToRWBufferRWTexture(const Expr *lhs,
                                                    SpirvInstruction *rhs);
 
+  /// Tries to emit instructions for assigning to the given mesh out attribute
+  /// or indices object. Returns 0 if the trial fails and no instructions are
+  /// generated.
+  SpirvInstruction *
+  tryToAssignToMSOutAttrsOrIndices(const Expr *lhs, SpirvInstruction *rhs,
+                                   SpirvInstruction *vecComponent = nullptr,
+                                   bool noWriteBack = false);
+
+  /// Emit instructions for assigning to the given mesh out attribute.
+  void assignToMSOutAttribute(
+      const DeclaratorDecl *decl, SpirvInstruction *value,
+      const llvm::SmallVector<SpirvInstruction *, 4> &indices);
+
+  /// Emit instructions for assigning to the given mesh out indices object.
+  void
+  assignToMSOutIndices(const DeclaratorDecl *decl, SpirvInstruction *value,
+                       const llvm::SmallVector<SpirvInstruction *, 4> &indices);
+
   /// Processes each vector within the given matrix by calling actOnEachVector.
   /// matrixVal should be the loaded value of the matrix. actOnEachVector takes
   /// three parameters for the current vector: the index, the <type-id>, and
@@ -295,7 +313,8 @@ private:
   const Expr *
   collectArrayStructIndices(const Expr *expr, bool rawIndex,
                             llvm::SmallVectorImpl<uint32_t> *rawIndices,
-                            llvm::SmallVectorImpl<SpirvInstruction *> *indices);
+                            llvm::SmallVectorImpl<SpirvInstruction *> *indices,
+                            bool *isMSOutAttribute = nullptr);
 
   /// Creates an access chain to index into the given SPIR-V evaluation result
   /// and returns the new SPIR-V evaluation result.
@@ -522,6 +541,12 @@ private:
   void processCallShader(const CallExpr *callExpr);
   void processTraceRay(const CallExpr *callExpr);
 
+  /// Process amplification shader intrinsics.
+  void processDispatchMesh(const CallExpr *callExpr);
+
+  /// Process mesh shader intrinsics.
+  void processMeshOutputCounts(const CallExpr *callExpr);
+
 private:
   /// Returns the <result-id> for constant value 0 of the given type.
   SpirvConstant *getValueZero(QualType type);
@@ -626,6 +651,12 @@ private:
   /// HLSL attributes of the entry point function.
   void processComputeShaderAttributes(const FunctionDecl *entryFunction);
 
+  /// \brief Adds necessary execution modes for the mesh/amplification shader
+  /// based on the HLSL attributes of the entry point function.
+  bool
+  processMeshOrAmplificationShaderAttributes(const FunctionDecl *decl,
+                                             uint32_t *outVerticesArraySize);
+
   /// \brief Emits a wrapper function for the entry function and returns true
   /// on success.
   ///
@@ -1106,7 +1137,7 @@ private:
   /// HitAttributeNV.
   llvm::SmallDenseMap<QualType,
                       std::pair<SpirvInstruction *, SpirvInstruction *>, 4>
-      payloadMap;
+      rayPayloadMap;
   llvm::SmallDenseMap<QualType, SpirvInstruction *, 4> hitAttributeMap;
   llvm::SmallDenseMap<QualType,
                       std::pair<SpirvInstruction *, SpirvInstruction *>, 4>

+ 32 - 0
tools/clang/lib/Sema/SemaExpr.cpp

@@ -608,6 +608,32 @@ static void DiagnoseDirectIsaAccess(Sema &S, const ObjCIvarRefExpr *OIRE,
     }
 }
 
+// Returns true when BaseExpr, after peeling off any array subscripts and
+// implicit casts, is a direct reference to a declaration that carries the
+// 'out' attribute together with one of the 'indices', 'vertices' or
+// 'primitives' attributes.
+static bool IsExprAccessingMeshOutArray(Expr* BaseExpr) {
+  Expr* cur = BaseExpr;
+  // Walk down through subscript bases and implicit cast operands.
+  for (;;) {
+    if (ArraySubscriptExpr* ase = dyn_cast<ArraySubscriptExpr>(cur)) {
+      cur = ase->getBase();
+    } else if (ImplicitCastExpr* ice = dyn_cast<ImplicitCastExpr>(cur)) {
+      cur = ice->getSubExpr();
+    } else {
+      break;
+    }
+  }
+
+  // Any other kind of expression is not a mesh output array access.
+  DeclRefExpr* dre = dyn_cast<DeclRefExpr>(cur);
+  if (!dre)
+    return false;
+
+  ValueDecl* vd = dre->getDecl();
+  if (!vd->hasAttr<HLSLOutAttr>())
+    return false;
+  return vd->hasAttr<HLSLIndicesAttr>() ||
+         vd->hasAttr<HLSLVerticesAttr>() ||
+         vd->hasAttr<HLSLPrimitivesAttr>();
+}
+
 ExprResult Sema::DefaultLvalueConversion(Expr *E) {
   // Handle any placeholder expressions which made it here.
   if (E->getType()->isPlaceholderType()) {
@@ -666,6 +692,12 @@ ExprResult Sema::DefaultLvalueConversion(Expr *E) {
             dyn_cast<ObjCIvarRefExpr>(E->IgnoreParenCasts()))
     DiagnoseDirectIsaAccess(*this, OIRE, SourceLocation(), /* Expr*/nullptr);
 
+  // check the access to mesh shader output arrays
+  if (isa<ArraySubscriptExpr>(E) && IsExprAccessingMeshOutArray(E)) {
+    Diag(E->getExprLoc(), diag::err_hlsl_load_from_mesh_out_arrays);
+    return ExprError();
+  }
+
   // C++ [conv.lval]p1:
   //   [...] If T is a non-class type, the type of the prvalue is the
   //   cv-unqualified version of T. Otherwise, the type of the

+ 3 - 1
tools/clang/lib/Sema/SemaExprCXX.cpp

@@ -3143,7 +3143,9 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Lvalue_To_Rvalue: {
     assert(From->getObjectKind() != OK_ObjCProperty);
     ExprResult FromRes = DefaultLvalueConversion(From);
-    assert(!FromRes.isInvalid() && "Can't perform deduced conversion?!");
+    if (FromRes.isInvalid()) {
+      return ExprError();
+    }
     From = FromRes.get();
     FromType = From->getType();
     break;

+ 245 - 15
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -171,6 +171,11 @@ enum ArBasicKind {
   AR_OBJECT_ROVTEXTURE2D_ARRAY,
   AR_OBJECT_ROVTEXTURE3D,
 
+  AR_OBJECT_FEEDBACKTEXTURE2D_MINLOD,
+  AR_OBJECT_FEEDBACKTEXTURE2D_TILED,
+  AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_MINLOD,
+  AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_TILED,
+
   // SPIRV change starts
 #ifdef ENABLE_SPIRV_CODEGEN
   AR_OBJECT_VK_SUBPASS_INPUT,
@@ -199,6 +204,9 @@ enum ArBasicKind {
   AR_OBJECT_TRIANGLE_HIT_GROUP,
   AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
 
+  // RayQuery
+  AR_OBJECT_RAY_QUERY,
+
   AR_BASIC_MAXIMUM_COUNT
 };
 
@@ -276,7 +284,8 @@ enum ArBasicKind {
 #define BPROP_PRIMITIVE         0x00100000  // Whether the type is a primitive scalar type.
 #define BPROP_MIN_PRECISION     0x00200000  // Whether the type is qualified with a minimum precision.
 #define BPROP_ROVBUFFER         0x00400000  // Whether the type is a ROV object.
-#define BPROP_ENUM              0x00800000  // Whether the type is a enum
+#define BPROP_FEEDBACKTEXTURE   0x00800000  // Whether the type is a feedback texture.
+#define BPROP_ENUM              0x01000000  // Whether the type is a enum
 
 #define GET_BPROP_PRIM_KIND(_Props) \
     ((_Props) & (BPROP_BOOLEAN | BPROP_INTEGER | BPROP_FLOATING))
@@ -448,6 +457,11 @@ const UINT g_uBasicKindProps[] =
   BPROP_OBJECT | BPROP_RWBUFFER | BPROP_ROVBUFFER,    // AR_OBJECT_ROVTEXTURE2D_ARRAY
   BPROP_OBJECT | BPROP_RWBUFFER | BPROP_ROVBUFFER,    // AR_OBJECT_ROVTEXTURE3D
 
+  BPROP_OBJECT | BPROP_TEXTURE | BPROP_FEEDBACKTEXTURE, // AR_OBJECT_FEEDBACKTEXTURE2D_MINLOD
+  BPROP_OBJECT | BPROP_TEXTURE | BPROP_FEEDBACKTEXTURE, // AR_OBJECT_FEEDBACKTEXTURE2D_TILED
+  BPROP_OBJECT | BPROP_TEXTURE | BPROP_FEEDBACKTEXTURE, // AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_MINLOD
+  BPROP_OBJECT | BPROP_TEXTURE | BPROP_FEEDBACKTEXTURE, // AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_TILED
+
   // SPIRV change starts
 #ifdef ENABLE_SPIRV_CODEGEN
   BPROP_OBJECT | BPROP_RBUFFER,   // AR_OBJECT_VK_SUBPASS_INPUT
@@ -476,6 +490,8 @@ const UINT g_uBasicKindProps[] =
   0,      //AR_OBJECT_TRIANGLE_HIT_GROUP,
   0,      //AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
 
+  0,      //AR_OBJECT_RAY_QUERY,
+
   // AR_BASIC_MAXIMUM_COUNT
 };
 
@@ -1092,6 +1108,18 @@ static const ArBasicKind g_SamplerCT[] =
   AR_BASIC_UNKNOWN
 };
 
+static const ArBasicKind g_Texture2DCT[] =
+{
+  AR_OBJECT_TEXTURE2D,
+  AR_BASIC_UNKNOWN
+};
+
+static const ArBasicKind g_Texture2DArrayCT[] =
+{
+  AR_OBJECT_TEXTURE2D_ARRAY,
+  AR_BASIC_UNKNOWN
+};
+
 static const ArBasicKind g_RayDescCT[] =
 {
   AR_OBJECT_RAY_DESC,
@@ -1203,8 +1231,11 @@ const ArBasicKind* g_LegalIntrinsicCompTypes[] =
   g_RayDescCT,          // LICOMPTYPE_RAYDESC
   g_AccelerationStructCT,   // LICOMPTYPE_ACCELERATION_STRUCT,
   g_UDTCT,              // LICOMPTYPE_USER_DEFINED_TYPE
+  g_Texture2DCT,        // LICOMPTYPE_TEXTURE2D
+  g_Texture2DArrayCT,   // LICOMPTYPE_TEXTURE2DARRAY
 };
-C_ASSERT(ARRAYSIZE(g_LegalIntrinsicCompTypes) == LICOMPTYPE_COUNT);
+static_assert(ARRAYSIZE(g_LegalIntrinsicCompTypes) == LICOMPTYPE_COUNT,
+  "Intrinsic comp type table must be updated when new enumerants are added.");
 
 // Decls.cpp constants ends here - these should be refactored or, better, replaced with clang::Type-based constructs.
 
@@ -1264,6 +1295,11 @@ const ArBasicKind g_ArBasicKindsAsTypes[] =
   AR_OBJECT_ROVTEXTURE2D_ARRAY,
   AR_OBJECT_ROVTEXTURE3D,
 
+  AR_OBJECT_FEEDBACKTEXTURE2D_MINLOD,
+  AR_OBJECT_FEEDBACKTEXTURE2D_TILED,
+  AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_MINLOD,
+  AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_TILED,
+
   // SPIRV change starts
 #ifdef ENABLE_SPIRV_CODEGEN
   AR_OBJECT_VK_SUBPASS_INPUT,
@@ -1286,7 +1322,9 @@ const ArBasicKind g_ArBasicKindsAsTypes[] =
   AR_OBJECT_RAYTRACING_SHADER_CONFIG,
   AR_OBJECT_RAYTRACING_PIPELINE_CONFIG,
   AR_OBJECT_TRIANGLE_HIT_GROUP,
-  AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP
+  AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
+
+  AR_OBJECT_RAY_QUERY
 };
 
 // Count of template arguments for basic kind of objects that look like templates (one or more type arguments).
@@ -1345,6 +1383,11 @@ const uint8_t g_ArBasicKindsTemplateCount[] =
   1, // AR_OBJECT_ROVTEXTURE2D_ARRAY
   1, // AR_OBJECT_ROVTEXTURE3D
 
+  0, // AR_OBJECT_FEEDBACKTEXTURE2D_MINLOD
+  0, // AR_OBJECT_FEEDBACKTEXTURE2D_TILED
+  0, // AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_MINLOD
+  0, // AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_TILED
+
   // SPIRV change starts
 #ifdef ENABLE_SPIRV_CODEGEN
   1, // AR_OBJECT_VK_SUBPASS_INPUT
@@ -1366,6 +1409,8 @@ const uint8_t g_ArBasicKindsTemplateCount[] =
   0, // AR_OBJECT_RAYTRACING_PIPELINE_CONFIG,
   0, // AR_OBJECT_TRIANGLE_HIT_GROUP,
   0, // AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
+
+  1, // AR_OBJECT_RAY_QUERY,
 };
 
 C_ASSERT(_countof(g_ArBasicKindsAsTypes) == _countof(g_ArBasicKindsTemplateCount));
@@ -1434,6 +1479,11 @@ const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] =
   { 3, MipsFalse, SampleFalse }, // AR_OBJECT_ROVTEXTURE2D_ARRAY (ROVTexture2DArray)
   { 3, MipsFalse, SampleFalse }, // AR_OBJECT_ROVTEXTURE3D (ROVTexture3D)
 
+  { 0, MipsFalse, SampleFalse }, // AR_OBJECT_FEEDBACKTEXTURE2D_MINLOD
+  { 0, MipsFalse, SampleFalse }, // AR_OBJECT_FEEDBACKTEXTURE2D_TILED
+  { 0, MipsFalse, SampleFalse }, // AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_MINLOD
+  { 0, MipsFalse, SampleFalse }, // AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_TILED
+
   // SPIRV change starts
 #ifdef ENABLE_SPIRV_CODEGEN
   { 0, MipsFalse, SampleFalse }, // AR_OBJECT_VK_SUBPASS_INPUT (SubpassInput)
@@ -1456,6 +1506,7 @@ const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] =
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_TRIANGLE_HIT_GROUP,
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
 
+  { 0, MipsFalse, SampleFalse },  // AR_OBJECT_RAY_QUERY,
 };
 
 C_ASSERT(_countof(g_ArBasicKindsAsTypes) == _countof(g_ArBasicKindsSubscripts));
@@ -1544,6 +1595,11 @@ const char* g_ArBasicTypeNames[] =
   "RasterizerOrderedTexture2DArray",
   "RasterizerOrderedTexture3D",
 
+  "FeedbackTexture2DMinLOD",
+  "FeedbackTexture2DTiled",
+  "FeedbackTexture2DArrayMinLOD",
+  "FeedbackTexture2DArrayTiled",
+
   // SPIRV change starts
 #ifdef ENABLE_SPIRV_CODEGEN
   "SubpassInput",
@@ -1568,7 +1624,9 @@ const char* g_ArBasicTypeNames[] =
   "RaytracingShaderConfig",
   "RaytracingPipelineConfig",
   "TriangleHitGroup",
-  "ProceduralPrimitiveHitGroup"
+  "ProceduralPrimitiveHitGroup",
+
+  "RayQuery"
 };
 
 C_ASSERT(_countof(g_ArBasicTypeNames) == AR_BASIC_MAXIMUM_COUNT);
@@ -2089,6 +2147,16 @@ void GetIntrinsicMethods(ArBasicKind kind, _Outptr_result_buffer_(*intrinsicCoun
     *intrinsics = g_RWTexture3DMethods;
     *intrinsicCount = _countof(g_RWTexture3DMethods);
     break;
+  case AR_OBJECT_FEEDBACKTEXTURE2D_MINLOD:
+  case AR_OBJECT_FEEDBACKTEXTURE2D_TILED:
+    *intrinsics = g_FeedbackTexture2DMethods;
+    *intrinsicCount = _countof(g_FeedbackTexture2DMethods);
+    break;
+  case AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_MINLOD:
+  case AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_TILED:
+    *intrinsics = g_FeedbackTexture2DArrayMethods;
+    *intrinsicCount = _countof(g_FeedbackTexture2DArrayMethods);
+    break;
   case AR_OBJECT_RWBUFFER:
   case AR_OBJECT_ROVBUFFER:
     *intrinsics = g_RWBufferMethods;
@@ -2120,7 +2188,11 @@ void GetIntrinsicMethods(ArBasicKind kind, _Outptr_result_buffer_(*intrinsicCoun
     *intrinsics = g_ConsumeStructuredBufferMethods;
     *intrinsicCount = _countof(g_ConsumeStructuredBufferMethods);
     break;
-  // SPIRV change starts
+  case AR_OBJECT_RAY_QUERY:
+    *intrinsics = g_RayQueryMethods;
+    *intrinsicCount = _countof(g_RayQueryMethods);
+    break;
+    // SPIRV change starts
 #ifdef ENABLE_SPIRV_CODEGEN
   case AR_OBJECT_VK_SUBPASS_INPUT:
     *intrinsics = g_VkSubpassInputMethods;
@@ -3219,6 +3291,12 @@ private:
           recordDecl = CreateSubobjectProceduralPrimitiveHitGroup(*m_context);
           break;
         }
+      } else if (kind == AR_OBJECT_RAY_QUERY) {
+        ClassTemplateDecl* typeDecl = nullptr;
+        AddRayQueryTemplate(*m_context, &typeDecl, &recordDecl);
+        DXASSERT(typeDecl != nullptr, "AddRayQueryTemplate failed to return the object declaration");
+        typeDecl->setImplicit(true);
+        recordDecl->setImplicit(true);
       }
       else if (templateArgCount == 0)
       {
@@ -3419,6 +3497,13 @@ public:
     return IsSubobjectBasicKind(GetTypeElementKind(type));
   }
 
+  bool IsRayQueryBasicKind(ArBasicKind kind) {
+    return kind == AR_OBJECT_RAY_QUERY;
+  }
+  bool IsRayQueryType(QualType type) {
+    return IsRayQueryBasicKind(GetTypeElementKind(type));
+  }
+
   void WarnMinPrecision(HLSLScalarType type, SourceLocation loc) {
     // TODO: enable this once we introduce precise master option
     bool UseMinPrecision = m_context->getLangOpts().UseMinPrecision;
@@ -4155,6 +4240,8 @@ public:
     AddRayFlags(*m_context);
     AddHitKinds(*m_context);
     AddStateObjectFlags(*m_context);
+    AddCommittedStatus(*m_context);
+    AddCandidateType(*m_context);
 
     return true;
   }
@@ -4882,7 +4969,8 @@ QualType GetFirstElementTypeFromDecl(const Decl* decl)
   if (specialization) {
     const TemplateArgumentList& list = specialization->getTemplateArgs();
     if (list.size()) {
-      return list[0].getAsType();
+      if (list[0].getKind() == TemplateArgument::ArgKind::Type)
+        return list[0].getAsType();
     }
   }
 
@@ -5603,7 +5691,12 @@ bool HLSLExternalSource::MatchArguments(
         return false;
       }
       pNewType = objectElement;
-    } else {
+    }
+    else if (pArgument->uLegalComponentTypes == LICOMPTYPE_TEXTURE2D
+      || pArgument->uLegalComponentTypes == LICOMPTYPE_TEXTURE2DARRAY) {
+      pNewType = Args[i - 1]->getType().getNonReferenceType();
+    }
+    else {
       ArBasicKind pEltType;
 
       // ComponentType, if the Id is special then it gets the
@@ -7323,6 +7416,29 @@ VectorMemberAccessError TryParseVectorMemberAccess(_In_z_ const char* memberText
   return VectorMemberAccessError_None;
 }
 
+static bool IsExprAccessingOutIndicesArray(Expr* BaseExpr) {
+  switch(BaseExpr->getStmtClass()) {
+  case Stmt::ArraySubscriptExprClass: {
+    ArraySubscriptExpr* ase = cast<ArraySubscriptExpr>(BaseExpr);
+    return IsExprAccessingOutIndicesArray(ase->getBase());
+  }
+  case Stmt::ImplicitCastExprClass: {
+    ImplicitCastExpr* ice = cast<ImplicitCastExpr>(BaseExpr);
+    return IsExprAccessingOutIndicesArray(ice->getSubExpr());
+  }
+  case Stmt::DeclRefExprClass: {
+    DeclRefExpr* dre = cast<DeclRefExpr>(BaseExpr);
+    ValueDecl* vd = dre->getDecl();
+    if (vd->getAttr<HLSLIndicesAttr>() && vd->getAttr<HLSLOutAttr>()) {
+      return true;
+    }
+    return false;
+  }
+  default:
+    return false;
+  }
+}
+
 bool HLSLExternalSource::LookupVectorMemberExprForHLSL(
     Expr& BaseExpr,
     DeclarationName MemberName,
@@ -7396,6 +7512,14 @@ bool HLSLExternalSource::LookupVectorMemberExprForHLSL(
 
   DXASSERT(positions.IsValid, "otherwise an error should have been returned");
 
+  // Disallow component access for out indices for DXIL path. We still allow
+  // this in SPIR-V path.
+  if (!getSema()->getLangOpts().SPIRV &&
+      IsExprAccessingOutIndicesArray(&BaseExpr) && positions.Count < colCount) {
+    m_sema->Diag(MemberLoc, diag::err_hlsl_out_indices_array_incorrect_access);
+    return false;
+  }
+
   // Consume elements
   QualType resultType;
   if (positions.Count == 1)
@@ -9624,6 +9748,10 @@ void hlsl::DiagnoseRegisterType(
   case AR_OBJECT_ROVTEXTURE2D:
   case AR_OBJECT_ROVTEXTURE2D_ARRAY:
   case AR_OBJECT_ROVTEXTURE3D:
+  case AR_OBJECT_FEEDBACKTEXTURE2D_MINLOD:
+  case AR_OBJECT_FEEDBACKTEXTURE2D_TILED:
+  case AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_MINLOD:
+  case AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY_TILED:
     expected = "'u'";
     isValid = registerType == 'u';
     break;
@@ -9713,8 +9841,9 @@ void hlsl::DiagnoseTranslationUnit(clang::Sema *self) {
     if (shaderModel->IsGS()) {
       // Validate that GS has the maxvertexcount attribute
       if (!pEntryPointDecl->hasAttr<HLSLMaxVertexCountAttr>()) {
-        self->Diag(pEntryPointDecl->getLocation(),
-                   diag::err_hlsl_missing_maxvertexcount_attr);
+        self->Diag(pEntryPointDecl->getLocation(), diag::err_hlsl_missing_attr)
+            << "GS"
+            << "maxvertexcount";
         return;
       }
     } else if (shaderModel->IsHS()) {
@@ -9731,8 +9860,32 @@ void hlsl::DiagnoseTranslationUnit(clang::Sema *self) {
         }
         pPatchFnDecl = NL.Found;
       } else {
-        self->Diag(pEntryPointDecl->getLocation(),
-                   diag::err_hlsl_missing_patchconstantfunc_attr);
+        self->Diag(pEntryPointDecl->getLocation(), diag::err_hlsl_missing_attr)
+            << "HS"
+            << "patchconstantfunc";
+        return;
+      }
+    } else if (shaderModel->IsMS()) {
+      // Validate that MS has the numthreads attribute
+      if (!pEntryPointDecl->hasAttr<HLSLNumThreadsAttr>()) {
+        self->Diag(pEntryPointDecl->getLocation(), diag::err_hlsl_missing_attr)
+            << "MS"
+            << "numthreads";
+        return;
+      }
+      // Validate that MS has the outputtopology attribute
+      if (!pEntryPointDecl->hasAttr<HLSLOutputTopologyAttr>()) {
+        self->Diag(pEntryPointDecl->getLocation(), diag::err_hlsl_missing_attr)
+            << "MS"
+            << "outputtopology";
+        return;
+      }
+    } else if (shaderModel->IsAS()) {
+      // Validate that AS has the numthreads attribute
+      if (!pEntryPointDecl->hasAttr<HLSLNumThreadsAttr>()) {
+        self->Diag(pEntryPointDecl->getLocation(), diag::err_hlsl_missing_attr)
+            << "AS"
+            << "numthreads";
         return;
       }
     }
@@ -10892,6 +11045,22 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
     declAttr = ::new (S.Context) HLSLGloballyCoherentAttr(
         A.getRange(), S.Context, A.getAttributeSpellingListIndex());
     break;
+  case AttributeList::AT_HLSLIndices:
+    declAttr = ::new (S.Context) HLSLIndicesAttr(
+        A.getRange(), S.Context, A.getAttributeSpellingListIndex());
+    break;
+  case AttributeList::AT_HLSLVertices:
+    declAttr = ::new (S.Context) HLSLVerticesAttr(
+        A.getRange(), S.Context, A.getAttributeSpellingListIndex());
+    break;
+  case AttributeList::AT_HLSLPrimitives:
+    declAttr = ::new (S.Context) HLSLPrimitivesAttr(
+        A.getRange(), S.Context, A.getAttributeSpellingListIndex());
+    break;
+  case AttributeList::AT_HLSLPayload:
+    declAttr = ::new (S.Context) HLSLPayloadAttr(
+        A.getRange(), S.Context, A.getAttributeSpellingListIndex());
+    break;
 
   default:
     Handled = false;
@@ -10975,8 +11144,10 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
   case AttributeList::AT_HLSLShader:
     declAttr = ::new (S.Context) HLSLShaderAttr(
         A.getRange(), S.Context,
-        ValidateAttributeStringArg(S, A,
-                                   "compute,vertex,pixel,hull,domain,geometry,raygeneration,intersection,anyhit,closesthit,miss,callable"),
+        ValidateAttributeStringArg(
+            S, A,
+            "compute,vertex,pixel,hull,domain,geometry,raygeneration,"
+            "intersection,anyhit,closesthit,miss,callable,mesh,amplification"),
         A.getAttributeSpellingListIndex());
     break;
   case AttributeList::AT_HLSLMaxVertexCount:
@@ -11017,7 +11188,7 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
   {
   case AttributeList::AT_VKBuiltIn:
     declAttr = ::new (S.Context) VKBuiltInAttr(A.getRange(), S.Context,
-      ValidateAttributeStringArg(S, A, "PointSize,HelperInvocation,BaseVertex,BaseInstance,DrawIndex,DeviceIndex"),
+      ValidateAttributeStringArg(S, A, "PointSize,HelperInvocation,BaseVertex,BaseInstance,DrawIndex,DeviceIndex,ViewportMaskNV"),
       A.getAttributeSpellingListIndex());
     break;
   case AttributeList::AT_VKLocation:
@@ -11549,7 +11720,8 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC, Expr *BitWidth,
     *pCentroid = nullptr,
     *pCenter = nullptr,
     *pAnyLinear = nullptr,                   // first linear attribute found
-    *pTopology = nullptr;
+    *pTopology = nullptr,
+    *pMeshModifier = nullptr;
   bool usageIn = false;
   bool usageOut = false;
 
@@ -11733,6 +11905,29 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC, Expr *BitWidth,
       }
       break;
 
+    case AttributeList::AT_HLSLIndices:
+    case AttributeList::AT_HLSLVertices:
+    case AttributeList::AT_HLSLPrimitives:
+    case AttributeList::AT_HLSLPayload:
+      if (!(isParameter)) {
+        Diag(pAttr->getLoc(), diag::err_hlsl_varmodifierna)
+          << pAttr->getName() << declarationType << pAttr->getRange();
+        result = false;
+      }
+      if (pMeshModifier) {
+        if (pMeshModifier->getKind() == pAttr->getKind()) {
+          Diag(pAttr->getLoc(), diag::warn_hlsl_duplicate_specifier)
+            << pAttr->getName() << pAttr->getRange();
+        } else {
+          Diag(pAttr->getLoc(), diag::err_hlsl_varmodifiersna)
+            << pAttr->getName() << pMeshModifier->getName()
+            << declarationType << pAttr->getRange();
+          result = false;
+        }
+      }
+      pMeshModifier = pAttr;
+      break;
+
     default:
       break;
     }
@@ -11784,6 +11979,21 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC, Expr *BitWidth,
       result = false;
     }
   }
+  if (pMeshModifier) {
+    if (pMeshModifier->getKind() == AttributeList::Kind::AT_HLSLPayload) {
+      if (!usageIn) {
+        Diag(D.getLocStart(), diag::err_hlsl_missing_in_attr)
+            << pMeshModifier->getName();
+        result = false;
+      }
+    } else {
+      if (!usageOut) {
+        Diag(D.getLocStart(), diag::err_hlsl_missing_out_attr)
+            << pMeshModifier->getName();
+        result = false;
+      }
+    }
+  }
 
   // Validate that stream-output objects are marked as inout
   if (isParameter && !(usageIn && usageOut) &&
@@ -12313,6 +12523,22 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, con
     Out << "globallycoherent ";
     break;
 
+  case clang::attr::HLSLIndices:
+    Out << "indices ";
+    break;
+
+  case clang::attr::HLSLVertices:
+    Out << "vertices ";
+    break;
+
+  case clang::attr::HLSLPrimitives:
+    Out << "primitives ";
+    break;
+
+  case clang::attr::HLSLPayload:
+    Out << "payload ";
+    break;
+
   default:
     A->printPretty(Out, Policy);
     break;
@@ -12365,6 +12591,10 @@ bool hlsl::IsHLSLAttr(clang::attr::Kind AttrKind) {
   case clang::attr::HLSLTriangle:
   case clang::attr::HLSLTriangleAdj:
   case clang::attr::HLSLGloballyCoherent:
+  case clang::attr::HLSLIndices:
+  case clang::attr::HLSLVertices:
+  case clang::attr::HLSLPrimitives:
+  case clang::attr::HLSLPayload:
   case clang::attr::NoInline:
   case clang::attr::HLSLExport:
   case clang::attr::VKBinding:

A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 153 - 132
tools/clang/lib/Sema/gen_intrin_main_tables_15.h


+ 66 - 0
tools/clang/test/CodeGenHLSL/batch/declarations/resources/textures/feedback.hlsl

@@ -0,0 +1,66 @@
+// RUN: %dxc -E main -T ps_6_5 %s | FileCheck %s
+
+// Test FeedbackTexture2D*** and their WriteSamplerFeedback methods
+
+FeedbackTexture2DMinLOD feedbackMinLOD;
+FeedbackTexture2DTiled feedbackTiled;
+FeedbackTexture2DArrayMinLOD feedbackMinLODArray;
+FeedbackTexture2DArrayTiled feebackTiledArray;
+Texture2D<float> texture2D;
+Texture2D<float4> texture2D_float4;
+Texture2DArray<float> texture2DArray;
+SamplerState samp;
+
+float main() : SV_Target
+{
+    float2 coords2D = float2(1, 2);
+    float3 coords2DArray = float3(1, 2, 3);
+    float clamp = 4;
+    float bias = 5;
+    float lod = 6;
+    float ddx = 7;
+    float ddy = 8;
+    
+    // Test every dxil intrinsic
+    // CHECK: call void @dx.op.writeSamplerFeedback(
+    // CHECK: float 1.000000e+00, float 2.000000e+00, float undef, float 4.000000e+00)
+    feedbackMinLOD.WriteSamplerFeedback(texture2D, samp, coords2D, clamp);
+    // CHECK: call void @dx.op.writeSamplerFeedbackBias(
+    // CHECK: float 1.000000e+00, float 2.000000e+00, float undef, float 5.000000e+00, float 4.000000e+00)
+    feedbackMinLOD.WriteSamplerFeedbackBias(texture2D, samp, coords2D, bias, clamp);
+    // CHECK: call void @dx.op.writeSamplerFeedbackLevel(
+    // CHECK: float 1.000000e+00, float 2.000000e+00, float undef, float 6.000000e+00)
+    feedbackMinLOD.WriteSamplerFeedbackLevel(texture2D, samp, coords2D, lod);
+    // CHECK: call void @dx.op.writeSamplerFeedbackGrad(
+    // CHECK: float 1.000000e+00, float 2.000000e+00, float undef, float 7.000000e+00, float 8.000000e+00, float 4.000000e+00)
+    feedbackMinLOD.WriteSamplerFeedbackGrad(texture2D, samp, coords2D, ddx, ddy, clamp);
+    
+    // Test with undef clamp
+    // CHECK: call void @dx.op.writeSamplerFeedback(
+    // CHECK: float 1.000000e+00, float 2.000000e+00, float undef, float undef)
+    feedbackMinLOD.WriteSamplerFeedback(texture2D, samp, coords2D);
+    // CHECK: call void @dx.op.writeSamplerFeedbackBias(
+    // CHECK: float 1.000000e+00, float 2.000000e+00, float undef, float 5.000000e+00, float undef)
+    feedbackMinLOD.WriteSamplerFeedbackBias(texture2D, samp, coords2D, bias);
+    // CHECK: call void @dx.op.writeSamplerFeedbackGrad(
+    // CHECK: float 1.000000e+00, float 2.000000e+00, float undef, float 7.000000e+00, float 8.000000e+00, float undef)
+    feedbackMinLOD.WriteSamplerFeedbackGrad(texture2D, samp, coords2D, ddx, ddy);
+
+    // Test on every FeedbackTexture variant
+    // CHECK: call void @dx.op.writeSamplerFeedback(
+    // CHECK: float 1.000000e+00, float 2.000000e+00, float undef, float undef)
+    feedbackTiled.WriteSamplerFeedback(texture2D, samp, coords2D);
+    // CHECK: call void @dx.op.writeSamplerFeedback(
+    // CHECK: float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float undef)
+    feedbackMinLODArray.WriteSamplerFeedback(texture2DArray, samp, coords2DArray);
+    // CHECK: call void @dx.op.writeSamplerFeedback(
+    // CHECK: float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float undef)
+    feebackTiledArray.WriteSamplerFeedback(texture2DArray, samp, coords2DArray);
+
+    // Test with overloaded texture type
+    // CHECK: call void @dx.op.writeSamplerFeedback(
+    // CHECK: float 1.000000e+00, float 2.000000e+00, float undef, float undef)
+    feedbackMinLOD.WriteSamplerFeedback(texture2D_float4, samp, coords2D);
+
+    return 0;
+}

+ 1 - 1
tools/clang/test/CodeGenHLSL/batch/expressions/intrinsics/misc/abs1.hlsl

@@ -2,7 +2,7 @@
 
 // CHECK: main
 // After lowering, these would turn into multiple abs calls rather than a 4 x float
-// CHECK: call <4 x float> @"dx.hl.op..<4 x float> (i32, <4 x float>)"(i32 91,
+// CHECK: call <4 x float> @"dx.hl.op..<4 x float> (i32, <4 x float>)"(i32 94,
 
 float4 main(float4 a : A) : SV_TARGET {
   return abs(a*a.yxxx);

+ 19 - 0
tools/clang/test/CodeGenHLSL/batch/shader_stages/raytracing/rayquery/tracerayinline.hlsl

@@ -0,0 +1,19 @@
+// RUN: %dxc -T vs_6_5 -E main %s | FileCheck %s
+
+// CHECK: %[[RTAS:[^ ]+]] = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 0, i32 0, i32 0, i1 false)
+// CHECK: %[[RQ:[^ ]+]] = call i32 @dx.op.allocateRayQuery(i32 178, i32 1)
+// CHECK: call void @dx.op.rayQuery_TraceRayInline(i32 179, i32 %[[RQ]], %dx.types.Handle %[[RTAS]], i32 0, i32 1,
+// CHECK: call void @dx.op.rayQuery_TraceRayInline(i32 179, i32 %[[RQ]], %dx.types.Handle %[[RTAS]], i32 1, i32 2,
+
+RaytracingAccelerationStructure RTAS;
+
+void DoTrace(RayQuery<RAY_FLAG_FORCE_OPAQUE> rayQuery, RayDesc rayDesc) {
+  rayQuery.TraceRayInline(RTAS, 0, 1, rayDesc);
+}
+
+float main(RayDesc rayDesc : RAYDESC) : OUT {
+  RayQuery<RAY_FLAG_FORCE_OPAQUE> rayQuery;
+  DoTrace(rayQuery, rayDesc);
+  rayQuery.TraceRayInline(RTAS, 1, 2, rayDesc);
+  return 0;
+}

+ 205 - 0
tools/clang/test/CodeGenHLSL/batch/shader_stages/raytracing/rayquery/tryAllOps.hlsl

@@ -0,0 +1,205 @@
+// RUN: %dxc -T cs_6_5 -E CS %s | FileCheck %s
+
+// CHECK: define void @CS()
+
+// RayQuery alloca should have been dead-code eliminated
+// CHECK-NOT: alloca
+
+// CHECK: %[[hAccelerationStructure:[^ ]+]] = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 0, i32 0, i32 0, i1 false)
+// CHECK: %[[hRayQuery:[^ ]+]] = call i32 @dx.op.allocateRayQuery(i32 178, i32 5)
+// CHECK: call void @dx.op.rayQuery_TraceRayInline(i32 179, i32 %[[hRayQuery]], %dx.types.Handle %[[hAccelerationStructure]], i32 0, i32 255, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 9.999000e+03)
+// CHECK: call i1 @dx.op.rayQuery_Proceed.i1(i32 180, i32 %[[hRayQuery]])
+// CHECK: call i32 @dx.op.rayQuery_StateScalar.i32(i32 185, i32 %[[hRayQuery]])
+// CHECK: call void @dx.op.rayQuery_Abort(i32 181, i32 %[[hRayQuery]])
+// CHECK: call float @dx.op.rayQuery_StateMatrix.f32(i32 186, i32 %[[hRayQuery]], i32 0, i8 0)
+// CHECK: call void @dx.op.rayQuery_CommitNonOpaqueTriangleHit(i32 182, i32 %[[hRayQuery]])
+// CHECK: call i1 @dx.op.rayQuery_StateScalar.i1(i32 191, i32 %[[hRayQuery]])
+// CHECK: call float @dx.op.rayQuery_StateVector.f32(i32 193, i32 %[[hRayQuery]], i8 0)
+// CHECK: call i32 @dx.op.rayQuery_StateScalar.i32(i32 203, i32 %[[hRayQuery]])
+// CHECK: call i32 @dx.op.rayQuery_StateScalar.i32(i32 202, i32 %[[hRayQuery]])
+// CHECK: call i32 @dx.op.rayQuery_StateScalar.i32(i32 201, i32 %[[hRayQuery]])
+// CHECK: call float @dx.op.rayQuery_StateVector.f32(i32 206, i32 %[[hRayQuery]], i8 0)
+// CHECK: call float @dx.op.rayQuery_StateVector.f32(i32 205, i32 %[[hRayQuery]], i8 1)
+// CHECK: call i32 @dx.op.rayQuery_StateScalar.i32(i32 204, i32 %[[hRayQuery]])
+// CHECK: call float @dx.op.rayQuery_StateScalar.f32(i32 199, i32 %[[hRayQuery]])
+// CHECK: call i1 @dx.op.rayQuery_Proceed.i1(i32 180, i32 %[[hRayQuery]])
+// CHECK: call float @dx.op.rayQuery_StateMatrix.f32(i32 187, i32 %[[hRayQuery]], i32 0, i8 0)
+// CHECK: call i1 @dx.op.rayQuery_StateScalar.i1(i32 190, i32 %[[hRayQuery]])
+// CHECK: call void @dx.op.rayQuery_CommitProceduralPrimitiveHit(i32 183, i32 %[[hRayQuery]], float 5.000000e-01)
+// CHECK: call void @dx.op.rayQuery_Abort(i32 181, i32 %[[hRayQuery]])
+// CHECK: call i32 @dx.op.rayQuery_StateScalar.i32(i32 184, i32 %[[hRayQuery]])
+// CHECK: call float @dx.op.rayQuery_StateMatrix.f32(i32 188, i32 %[[hRayQuery]], i32 0, i8 0)
+// CHECK: call float @dx.op.rayQuery_StateMatrix.f32(i32 189, i32 %[[hRayQuery]], i32 0, i8 0)
+// CHECK: call i1 @dx.op.rayQuery_StateScalar.i1(i32 192, i32 %[[hRayQuery]])
+// CHECK: call float @dx.op.rayQuery_StateVector.f32(i32 194, i32 %[[hRayQuery]], i8 1)
+// CHECK: call i32 @dx.op.rayQuery_StateScalar.i32(i32 209, i32 %[[hRayQuery]])
+// CHECK: call i32 @dx.op.rayQuery_StateScalar.i32(i32 208, i32 %[[hRayQuery]])
+// CHECK: call i32 @dx.op.rayQuery_StateScalar.i32(i32 207, i32 %[[hRayQuery]])
+// CHECK: call float @dx.op.rayQuery_StateVector.f32(i32 212, i32 %[[hRayQuery]], i8 2)
+// CHECK: call float @dx.op.rayQuery_StateVector.f32(i32 211, i32 %[[hRayQuery]], i8 0)
+// CHECK: call i32 @dx.op.rayQuery_StateScalar.i32(i32 210, i32 %[[hRayQuery]])
+// CHECK: call float @dx.op.rayQuery_StateScalar.f32(i32 200, i32 %[[hRayQuery]])
+// CHECK: call i32 @dx.op.rayQuery_StateScalar.i32(i32 195, i32 %[[hRayQuery]])
+// CHECK: call float @dx.op.rayQuery_StateScalar.f32(i32 198, i32 %[[hRayQuery]])
+// CHECK: call float @dx.op.rayQuery_StateVector.f32(i32 197, i32 %[[hRayQuery]], i8 0)
+// CHECK: call float @dx.op.rayQuery_StateVector.f32(i32 196, i32 %[[hRayQuery]], i8 2)
+
+RaytracingAccelerationStructure AccelerationStructure : register(t0);
+RWByteAddressBuffer log : register(u0);
+
+RayDesc MakeRayDesc()
+{
+    RayDesc desc;
+    desc.Origin = float3(0,0,0);
+    desc.Direction = float3(1,0,0);
+    desc.TMin = 0.0f;
+    desc.TMax = 9999.0;
+    return desc;
+}
+
+void DoSomething()
+{
+    log.Store(0,1);
+}
+
+[numThreads(1,1,1)]
+void CS()
+{
+    RayQuery<RAY_FLAG_FORCE_OPAQUE|RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH> q;
+    RayDesc ray = MakeRayDesc();
+    q.TraceRayInline(AccelerationStructure,RAY_FLAG_NONE,0xFF,ray);
+    float4x3 mat4x3;
+    float3x4 mat3x4;
+    while(q.Proceed())
+    {
+        switch(q.CandidateType())
+        {
+        case CANDIDATE_NON_OPAQUE_TRIANGLE:
+            q.Abort();
+            mat3x4 = q.CandidateObjectToWorld3x4();
+            mat4x3 = q.CandidateObjectToWorld4x3();
+            q.CommitNonOpaqueTriangleHit();
+            if(q.CandidateTriangleFrontFace())
+            {
+                DoSomething();
+            }
+            if(q.CandidateTriangleBarycentrics().x == 0)
+            {
+                DoSomething();
+            }
+            if(q.CandidateGeometryIndex())
+            {
+                DoSomething();
+            }
+            if(q.CandidateInstanceID())
+            {
+                DoSomething();
+            }
+            if(q.CandidateInstanceIndex())
+            {
+                DoSomething();
+            }
+            if(q.CandidateObjectRayDirection().x)
+            {
+                DoSomething();
+            }
+            if(q.CandidateObjectRayOrigin().y)
+            {
+                DoSomething();
+            }
+            if(q.CandidatePrimitiveIndex())
+            {
+                DoSomething();
+            }
+            if(q.CandidateTriangleRayT())
+            {
+                DoSomething();
+            }
+            break;
+        case CANDIDATE_PROCEDURAL_PRIMITIVE:
+        {
+            mat3x4 = q.CandidateWorldToObject3x4();
+            mat4x3 = q.CandidateWorldToObject4x3();
+            if(q.CandidateProceduralPrimitiveNonOpaque())
+            {
+                DoSomething();
+            }
+            float t = 0.5;
+            q.CommitProceduralPrimitiveHit(t);
+            q.Abort();
+            break;
+        }
+        }
+    }
+    if(mat3x4[0][0] == mat4x3[0][0])
+    {
+        DoSomething();
+    }
+    switch(q.CommittedStatus())
+    {
+    case COMMITTED_NOTHING:
+        mat3x4 = q.CommittedObjectToWorld3x4();
+        mat4x3 = q.CommittedObjectToWorld4x3();
+        break;
+    case COMMITTED_TRIANGLE_HIT:
+        mat3x4 = q.CommittedWorldToObject3x4();
+        mat4x3 = q.CommittedWorldToObject4x3();
+        if(q.CommittedTriangleFrontFace())
+        {
+            DoSomething();
+        }
+        if(q.CommittedTriangleBarycentrics().y == 0)
+        {
+            DoSomething();
+        }
+        break;
+    case COMMITTED_PROCEDURAL_PRIMITIVE_HIT:
+        if(q.CommittedGeometryIndex())
+        {
+            DoSomething();
+        }
+        if(q.CommittedInstanceID())
+        {
+            DoSomething();
+        }
+        if(q.CommittedInstanceIndex())
+        {
+            DoSomething();
+        }
+        if(q.CommittedObjectRayDirection().z)
+        {
+            DoSomething();
+        }
+        if(q.CommittedObjectRayOrigin().x)
+        {
+            DoSomething();
+        }
+        if(q.CommittedPrimitiveIndex())
+        {
+            DoSomething();
+        }
+        if(q.CommittedRayT())
+        {
+            DoSomething();
+        }
+        break;
+    }
+    if(mat3x4[0][0] == mat4x3[0][0])
+    {
+        DoSomething();
+    }
+    if(q.RayFlags())
+    {
+        DoSomething();
+    }
+    if(q.RayTMin())
+    {
+        DoSomething();
+    }
+    float3 o = q.WorldRayDirection();
+    float3 d = q.WorldRayOrigin();
+    if(o.x == d.z)
+    {
+        DoSomething();
+    }
+}

+ 39 - 0
tools/clang/test/CodeGenHLSL/batch/shader_stages/raytracing/raytracing_anyhit_geometryIndex.hlsl

@@ -0,0 +1,39 @@
+// RUN: %dxc -T lib_6_5 -auto-binding-space 11 %s | FileCheck %s
+
+// CHECK: define void [[anyhit1:@"\\01\?anyhit1@[^\"]+"]](%struct.MyPayload* noalias nocapture %payload, %struct.MyAttributes* nocapture readonly %attr) #0 {
+// CHECK:   call float @dx.op.objectRayOrigin.f32(i32 149, i8 2)
+// CHECK:   call float @dx.op.objectRayDirection.f32(i32 150, i8 2)
+// CHECK:   call float @dx.op.rayTCurrent.f32(i32 154)
+// CHECK:   call void @dx.op.acceptHitAndEndSearch(i32 156)
+// CHECK:   call void @dx.op.ignoreHit(i32 155)
+// CHECK:   [[GeometryIndex:%[^ ]+]] = call i32 @dx.op.geometryIndex.i32(i32 213)
+// CHECK:   icmp eq i32 [[GeometryIndex]], 0
+// CHECK:   %[[color:[^ ]+]] = getelementptr inbounds %struct.MyPayload, %struct.MyPayload* %payload, i32 0, i32 0
+// CHECK:   store <4 x float> {{.*}}, <4 x float>* %[[color]], align 4
+// CHECK:   ret void
+
+struct MyPayload {
+  float4 color;
+  uint2 pos;
+};
+
+struct MyAttributes {
+  float2 bary;
+  uint id;
+};
+
+[shader("anyhit")] void anyhit1(inout MyPayload payload
+                                : SV_RayPayload,
+                                  in MyAttributes attr
+                                : SV_IntersectionAttributes) {
+  float3 hitLocation = ObjectRayOrigin() + ObjectRayDirection() * RayTCurrent();
+  if (hitLocation.z < attr.bary.x)
+    AcceptHitAndEndSearch(); // aborts function
+  if (hitLocation.z < attr.bary.y)
+    IgnoreHit(); // aborts function
+  if (GeometryIndex() == 0) {
+    payload.color += float4(0.125, 0.25, 0.5, 1.0);
+  } else {
+    payload.color += float4(0.2, 0.3, 0.3, 1.0);
+  }
+}

+ 30 - 0
tools/clang/test/CodeGenHLSL/batch/shader_stages/raytracing/raytracing_closesthit_geometryIndex.hlsl

@@ -0,0 +1,30 @@
+// RUN: %dxc -T lib_6_5 -auto-binding-space 11 %s | FileCheck %s
+
+// CHECK: define void [[closesthit1:@"\\01\?closesthit1@[^\"]+"]](%struct.MyPayload* noalias nocapture %payload, %struct.BuiltInTriangleIntersectionAttributes* nocapture readonly %attr) #0 {
+// CHECK:   [[GeometryIndex:%[^ ]+]] = call i32 @dx.op.geometryIndex.i32(i32 213)
+// CHECK:   icmp eq i32 [[GeometryIndex]], 0
+// CHECK:   call void @dx.op.callShader.struct.MyParam(i32 159, i32 {{.*}}, %struct.MyParam* nonnull {{.*}})
+// CHECK:   %[[color:[^ ]+]] = getelementptr inbounds %struct.MyPayload, %struct.MyPayload* %payload, i32 0, i32 0
+// CHECK:   store <4 x float> {{.*}}, <4 x float>* %[[color]], align 4
+// CHECK:   ret void
+
+struct MyPayload {
+  float4 color;
+  uint2 pos;
+};
+
+struct MyParam {
+  float2 coord;
+  float4 output;
+};
+
+[shader("closesthit")] void closesthit1(inout MyPayload payload
+                                        : SV_RayPayload,
+                                          in BuiltInTriangleIntersectionAttributes attr
+                                        : SV_IntersectionAttributes) {
+  MyParam param = {attr.barycentrics, {0, 0, 0, 0}};
+  if (GeometryIndex() == 0) {
+    CallShader(7, param);  
+  }
+  payload.color += param.output;
+}

+ 21 - 0
tools/clang/test/CodeGenHLSL/batch/shader_stages/raytracing/raytracing_intersection_geometryIndex.hlsl

@@ -0,0 +1,21 @@
+// RUN: %dxc -T lib_6_5 -auto-binding-space 11 %s | FileCheck %s
+
+// CHECK: define void [[intersection1:@"\\01\?intersection1@[^\"]+"]]() #0 {
+// CHECK:   [[rayTCurrent:%[^ ]+]] = call float @dx.op.rayTCurrent.f32(i32 154)
+// CHECK:   [[GeometryIndex:%[^ ]+]] = call i32 @dx.op.geometryIndex.i32(i32 213)
+// CHECK:   icmp eq i32 [[GeometryIndex]], 0
+// CHECK:   call i1 @dx.op.reportHit.struct.MyAttributes(i32 158, float [[rayTCurrent]], i32 0, %struct.MyAttributes* nonnull {{.*}})
+// CHECK:   ret void
+
+struct MyAttributes {
+  float2 bary;
+  uint id;
+};
+
+[shader("intersection")] void intersection1() {
+  float hitT = RayTCurrent();
+  MyAttributes attr = (MyAttributes)0;
+  if (GeometryIndex() == 0) {
+    bool bReported = ReportHit(hitT, 0, attr);  
+  }
+}

+ 22 - 0
tools/clang/test/CodeGenHLSL/mesh-val/amplification.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: dx.op.dispatchMesh.struct.Payload
+
+#define NUM_THREADS 32
+
+struct Payload {
+    float2 dummy;
+    float4 pos;
+    float color[2];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main()
+{
+    Payload pld;
+    pld.dummy = float2(1.0,2.0);
+    pld.pos = float4(3.0,4.0,5.0,6.0);
+    pld.color[0] = 7.0;
+    pld.color[1] = 8.0;
+    DispatchMesh(NUM_THREADS, 1, 1, pld);
+}

+ 22 - 0
tools/clang/test/CodeGenHLSL/mesh-val/asOversizePayload.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: payload size is greater than 16384
+
+#define NUM_THREADS 32
+
+struct Payload {
+    float2 dummy;
+    float4 pos[1024];
+    float color[2];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main()
+{
+    Payload pld;
+    pld.dummy = float2(1.0,2.0);
+    pld.pos[0] = float4(3.0,4.0,5.0,6.0);
+    pld.color[0] = 7.0;
+    pld.color[1] = 8.0;
+    DispatchMesh(NUM_THREADS, 1, 1, pld);
+}

+ 78 - 0
tools/clang/test/CodeGenHLSL/mesh-val/mesh.hlsl

@@ -0,0 +1,78 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: dx.op.getMeshPayload.struct.MeshPayload
+// CHECK: dx.op.setMeshOutputCounts(i32 168, i32 32, i32 16)
+// CHECK: dx.op.emitIndices
+// CHECK: dx.op.storeVertexOutput
+// CHECK: dx.op.storePrimitiveOutput
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    float alnorm : ALNORM;
+    float ormaln : ORMALN;
+    int layer[6] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    float alnorm;
+    float ormaln;
+    int layer[6];
+};
+
+groupshared float gsMem[MAX_PRIM];
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = gsMem[tig / 3 + 1];
+      op.alnorm = mpl.alnorm;
+      op.ormaln = mpl.ormaln;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      op.layer[4] = mpl.layer[4];
+      op.layer[5] = mpl.layer[5];
+      gsMem[tig / 3] = op.normal;
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 21 - 0
tools/clang/test/CodeGenHLSL/mesh-val/missingDispatchMesh.hlsl

@@ -0,0 +1,21 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: DispatchMesh must be called exactly once in an Amplification shader.
+
+#define NUM_THREADS 32
+
+struct Payload {
+    float2 dummy;
+    float4 pos;
+    float color[2];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main()
+{
+    Payload pld;
+    pld.dummy = float2(1.0,2.0);
+    pld.pos = float4(3.0,4.0,5.0,6.0);
+    pld.color[0] = 7.0;
+    pld.color[1] = 8.0;
+}

+ 62 - 0
tools/clang/test/CodeGenHLSL/mesh-val/missingSetMeshOutputCounts.hlsl

@@ -0,0 +1,62 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: Missing SetMeshOutputCounts call.
+
+#define MAX_VERT 32  // length of the verts output array in main
+#define MAX_PRIM 16  // length of the primIndices and prims output arrays in main
+#define NUM_THREADS 32  // thread-group width used by the numthreads attribute on main
+struct MeshPerVertex {  // element type of the "out vertices" array
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {  // element type of the "out primitives" array
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {  // type of the "in payload" parameter
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(  // writes vertex, index and primitive outputs without ever calling SetMeshOutputCounts (the missing call named by the expected diagnostic)
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    MeshPerVertex ov;
+    if (vid % 2) {  // odd view IDs: position and color take the values 4..7
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {  // even view IDs: position and color take the values 14..17
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {  // only threads whose group index is not a multiple of 3 write index/primitive outputs
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;  // populated field by field from the payload
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;  // every thread writes its vertex at index tig
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh-val/msOversizePayload.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: payload size is greater than 16384
+
+#define MAX_VERT 32  // length of the verts output array; also the first SetMeshOutputCounts argument
+#define MAX_PRIM 16  // length of the primIndices and prims output arrays; also the second SetMeshOutputCounts argument
+#define NUM_THREADS 32  // thread-group width used by the numthreads attribute on main
+struct MeshPerVertex {  // element type of the "out vertices" array
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {  // element type of the "out primitives" array
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {  // type of the "in payload" parameter; deliberately large
+    float normal;
+    float malnor[4096];  // 4096 floats; assuming 4-byte floats this array alone is 16384 bytes, so the struct exceeds the limit in the expected diagnostic
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(  // otherwise ordinary entry point; the payload parameter's type is what the expected diagnostic is about
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);  // single unconditional call, made before any output is written
+    MeshPerVertex ov;
+    if (vid % 2) {  // odd view IDs: position and color take the values 4..7
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {  // even view IDs: position and color take the values 14..17
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {  // only threads whose group index is not a multiple of 3 write index/primitive outputs
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor[0];  // only the first element of the large payload array is read
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;  // every thread writes its vertex at index tig
+}

+ 23 - 0
tools/clang/test/CodeGenHLSL/mesh-val/multipleDispatchMesh.hlsl

@@ -0,0 +1,23 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: DispatchMesh must be called exactly once in an Amplification shader.
+
+#define NUM_THREADS 32  // used by the numthreads attribute and as the first argument of both DispatchMesh calls
+
+struct Payload {  // type of the local "pld" passed as the last DispatchMesh argument
+    float2 dummy;
+    float4 pos;
+    float color[2];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main()  // fills a Payload and then calls DispatchMesh twice; the expected diagnostic requires exactly one call
+{
+    Payload pld;
+    pld.dummy = float2(1.0,2.0);
+    pld.pos = float4(3.0,4.0,5.0,6.0);
+    pld.color[0] = 7.0;
+    pld.color[1] = 8.0;
+    DispatchMesh(NUM_THREADS, 1, 1, pld);  // first call
+    DispatchMesh(NUM_THREADS, 1, 1, pld);  // second, identical call on the same path
+}

+ 64 - 0
tools/clang/test/CodeGenHLSL/mesh-val/multipleSetMeshOutputCounts.hlsl

@@ -0,0 +1,64 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: SetMeshOUtputCounts cannot be called multiple times.
+
+#define MAX_VERT 32  // length of the verts output array; also the first SetMeshOutputCounts argument
+#define MAX_PRIM 16  // length of the primIndices and prims output arrays; also the second SetMeshOutputCounts argument
+#define NUM_THREADS 32  // thread-group width used by the numthreads attribute on main
+struct MeshPerVertex {  // element type of the "out vertices" array
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {  // element type of the "out primitives" array
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {  // type of the "in payload" parameter
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(  // calls SetMeshOutputCounts once up front and again inside a branch, so some paths call it twice
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);  // first call: unconditional
+    MeshPerVertex ov;
+    if (vid % 2) {  // odd view IDs: position and color take the values 4..7
+        SetMeshOutputCounts(MAX_VERT, MAX_PRIM);  // second call: reached only on this branch, after the first one
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {  // even view IDs: position and color take the values 14..17
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {  // only threads whose group index is not a multiple of 3 write index/primitive outputs
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;  // every thread writes its vertex at index tig
+}

+ 24 - 0
tools/clang/test/CodeGenHLSL/mesh-val/nonDominatingDispatchMesh.hlsl

@@ -0,0 +1,24 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: Non-Dominating DispatchMesh call.
+
+#define NUM_THREADS 32  // used by the numthreads attribute and as the first DispatchMesh argument
+
+struct Payload {  // type of the local "pld" passed as the last DispatchMesh argument
+    float2 dummy;
+    float4 pos;
+    float color[2];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main(in uint tid : SV_DispatchThreadID)  // calls DispatchMesh only inside a thread-dependent branch
+{
+    Payload pld;
+    pld.dummy = float2(1.0,2.0);
+    pld.pos = float4(3.0,4.0,5.0,6.0);
+    pld.color[0] = 7.0;
+    pld.color[1] = 8.0;
+    if (tid % 2) {  // taken only for odd thread IDs, so the call below is not reached on every path
+      DispatchMesh(NUM_THREADS, 1, 1, pld);
+    }
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh-val/nonDominatingSetMeshOutputCounts.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: Non-Dominating SetMeshOutputCounts call.
+
+#define MAX_VERT 32  // length of the verts output array; also the first SetMeshOutputCounts argument
+#define MAX_PRIM 16  // length of the primIndices and prims output arrays; also the second SetMeshOutputCounts argument
+#define NUM_THREADS 32  // thread-group width used by the numthreads attribute on main
+struct MeshPerVertex {  // element type of the "out vertices" array
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {  // element type of the "out primitives" array
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {  // type of the "in payload" parameter
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(  // calls SetMeshOutputCounts only inside one branch, yet writes outputs on every path
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    MeshPerVertex ov;
+    if (vid % 2) {  // odd view IDs: position and color take the values 4..7
+        SetMeshOutputCounts(MAX_VERT, MAX_PRIM);  // the only call; it is skipped when the else branch runs
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {  // even view IDs: position and color take the values 14..17
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {  // only threads whose group index is not a multiple of 3 write index/primitive outputs
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;  // every thread writes its vertex at index tig
+}

+ 74 - 0
tools/clang/test/CodeGenHLSL/mesh-val/oversizeSM.hlsl

@@ -0,0 +1,74 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: Total Thread Group Shared Memory storage is 28676, exceeded 28672
+
+#define MAX_VERT 32  // length of the verts output array; also the first SetMeshOutputCounts argument
+#define MAX_PRIM 16  // length of the primIndices and prims output arrays; also the second SetMeshOutputCounts argument
+#define NUM_THREADS 32  // thread-group width used by the numthreads attribute on main
+struct MeshPerVertex {  // element type of the "out vertices" array
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {  // element type of the "out primitives" array
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    float alnorm : ALNORM;
+    float ormaln : ORMALN;
+    int layer[6] : LAYER;
+};
+
+struct MeshPayload {  // type of the "in payload" parameter
+    float normal;
+    float malnor;
+    float alnorm;
+    float ormaln;
+    int layer[6];
+};
+
+groupshared float gsMem[1024 * 7 + 1];  // 7169 floats; 7169 * 4 = 28676, the byte total quoted in the expected diagnostic
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(  // entry point that reads and writes gsMem, so the oversized groupshared array is actually used
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);  // single unconditional call, made before any output is written
+    MeshPerVertex ov;
+    if (vid % 2) {  // odd view IDs: position and color take the values 4..7
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {  // even view IDs: position and color take the values 14..17
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {  // only threads whose group index is not a multiple of 3 write index/primitive outputs
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = gsMem[tig / 3 + 1];  // malnor comes from groupshared memory, not from the payload
+      op.alnorm = mpl.alnorm;
+      op.ormaln = mpl.ormaln;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      op.layer[4] = mpl.layer[4];
+      op.layer[5] = mpl.layer[5];
+      gsMem[tig / 3] = op.normal;  // store this primitive's normal back into groupshared memory
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;  // every thread writes its vertex at index tig
+}

+ 22 - 0
tools/clang/test/CodeGenHLSL/mesh/amplification.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: dx.op.dispatchMesh.struct.Payload
+
+#define NUM_THREADS 32  // used by the numthreads attribute and as the first DispatchMesh argument
+
+struct Payload {  // type of the local "pld" passed to DispatchMesh; its name appears in the expected dx.op call
+    float2 dummy;
+    float4 pos;
+    float color[2];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main()  // fills a Payload and hands it to a single, unconditional DispatchMesh call
+{
+    Payload pld;
+    pld.dummy = float2(1.0,2.0);
+    pld.pos = float4(3.0,4.0,5.0,6.0);
+    pld.color[0] = 7.0;
+    pld.color[1] = 8.0;
+    DispatchMesh(NUM_THREADS, 1, 1, pld);  // the one call in this shader, placed last
+}

+ 64 - 0
tools/clang/test/CodeGenHLSL/mesh/illegalOutIndicesAssignment.hlsl

@@ -0,0 +1,64 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: a vector in out indices array must be accessed as a whole
+
+#define MAX_VERT 32  // length of the verts output array; also the first SetMeshOutputCounts argument
+#define MAX_PRIM 16  // length of the primIndices and prims output arrays; also the second SetMeshOutputCounts argument
+#define NUM_THREADS 32  // thread-group width used by the numthreads attribute on main
+struct MeshPerVertex {  // element type of the "out vertices" array
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {  // element type of the "out primitives" array
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {  // type of the "in payload" parameter
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(  // same shape as the other mesh tests, plus one single-component write into primIndices
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);  // single unconditional call, made before any output is written
+    MeshPerVertex ov;
+    if (vid % 2) {  // odd view IDs: position and color take the values 4..7
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {  // even view IDs: position and color take the values 14..17
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {  // only threads whose group index is not a multiple of 3 write index/primitive outputs
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);  // whole-vector write
+      primIndices[tig / 3].x = 0;  // writes only the .x component; the expected error requires whole-vector access
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;  // every thread writes its vertex at index tig
+}

+ 78 - 0
tools/clang/test/CodeGenHLSL/mesh/mesh.hlsl

@@ -0,0 +1,78 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: dx.op.getMeshPayload.struct.MeshPayload
+// CHECK: dx.op.setMeshOutputCounts(i32 168, i32 32, i32 16)
+// CHECK: dx.op.emitIndices
+// CHECK: dx.op.storeVertexOutput
+// CHECK: dx.op.storePrimitiveOutput
+
+#define MAX_VERT 32  // length of the verts output array; also the first SetMeshOutputCounts argument
+#define MAX_PRIM 16  // length of primIndices, prims and gsMem; also the second SetMeshOutputCounts argument
+#define NUM_THREADS 32  // thread-group width used by the numthreads attribute on main
+struct MeshPerVertex {  // element type of the "out vertices" array
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {  // element type of the "out primitives" array
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    float alnorm : ALNORM;
+    float ormaln : ORMALN;
+    int layer[6] : LAYER;
+};
+
+struct MeshPayload {  // type of the "in payload" parameter; its name appears in the expected getMeshPayload call
+    float normal;
+    float malnor;
+    float alnorm;
+    float ormaln;
+    int layer[6];
+};
+
+groupshared float gsMem[MAX_PRIM];  // groupshared array; main reads gsMem[tig / 3 + 1] and writes gsMem[tig / 3]
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(  // positive test: reads the payload, sets output counts, and writes indices, vertices and primitives
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);  // arguments expand to 32 and 16, the trailing i32 operands in the expected setMeshOutputCounts call
+    MeshPerVertex ov;
+    if (vid % 2) {  // odd view IDs: position and color take the values 4..7
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {  // even view IDs: position and color take the values 14..17
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {  // only threads whose group index is not a multiple of 3 write index/primitive outputs
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = gsMem[tig / 3 + 1];  // malnor comes from groupshared memory, not from the payload
+      op.alnorm = mpl.alnorm;
+      op.ormaln = mpl.ormaln;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      op.layer[4] = mpl.layer[4];
+      op.layer[5] = mpl.layer[5];
+      gsMem[tig / 3] = op.normal;  // store this primitive's normal back into groupshared memory
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;  // every thread writes its vertex at index tig
+}

+ 64 - 0
tools/clang/test/CodeGenHLSL/mesh/multipleInPayload.hlsl

@@ -0,0 +1,64 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: multiple in payload parameters not allowed
+
+#define MAX_VERT 32  // length of the verts output array; also the first SetMeshOutputCounts argument
+#define MAX_PRIM 16  // length of the primIndices and prims output arrays; also the second SetMeshOutputCounts argument
+#define NUM_THREADS 32  // thread-group width used by the numthreads attribute on main
+struct MeshPerVertex {  // element type of the "out vertices" array
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {  // element type of the "out primitives" array
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {  // type of both "in payload" parameters of main
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(  // declares two "in payload" parameters, which the expected error says is not allowed
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in payload MeshPayload mpl2,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);  // single unconditional call, made before any output is written
+    MeshPerVertex ov;
+    if (vid % 2) {  // odd view IDs: position and color take the values 4..7
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {  // even view IDs: position and color take the values 14..17
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {  // only threads whose group index is not a multiple of 3 write index/primitive outputs
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;  // filled from a mix of both payload parameters so that each one is used
+      op.normal = mpl.normal;
+      op.malnor = mpl2.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl2.layer[2];
+      op.layer[3] = mpl2.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;  // every thread writes its vertex at index tig
+}

Nem az összes módosított fájl került megjelenítésre, mert túl sok fájl változott