Selaa lähdekoodia

Merged PR 116: Add support for HLSL Meshlets

This PR adds support for new HLSL mesh and amplification shaders.
Sahil Parmar 6 vuotta sitten
vanhempi
commit
968fe41136
100 muutettua tiedostoa jossa 5207 lisäystä ja 769 poistoa
  1. 95 65
      docs/DXIL.rst
  2. 124 18
      docs/SPIR-V.rst
  3. 1 1
      external/SPIRV-Headers
  4. 1 1
      external/SPIRV-Tools
  5. 73 12
      include/dxc/DXIL/DxilConstants.h
  6. 17 1
      include/dxc/DXIL/DxilFunctionProps.h
  7. 188 0
      include/dxc/DXIL/DxilInstructions.h
  8. 26 0
      include/dxc/DXIL/DxilMetadataHelper.h
  9. 12 2
      include/dxc/DXIL/DxilModule.h
  10. 3 1
      include/dxc/DXIL/DxilShaderModel.h
  11. 1 1
      include/dxc/DXIL/DxilSigPoint.h
  12. 78 52
      include/dxc/DXIL/DxilSigPoint.inl
  13. 4 2
      include/dxc/DXIL/DxilSignature.h
  14. 1 1
      include/dxc/DXIL/DxilSignatureElement.h
  15. 4 0
      include/dxc/DXIL/DxilTypeSystem.h
  16. 91 74
      include/dxc/DxilContainer/DxilPipelineStateValidation.h
  17. 3 3
      include/dxc/HLSL/ComputeViewIdState.h
  18. 17 0
      include/dxc/HLSL/DxilValidation.h
  19. 6 0
      include/dxc/HLSL/HLOperations.h
  20. 4 4
      include/dxc/HLSL/ViewIDPipelineValidation.inl
  21. 2 0
      include/dxc/HlslIntrinsicOp.h
  22. 1 1
      include/dxc/Support/HLSLOptions.td
  23. 1 0
      include/dxc/Support/SPIRVOptions.h
  24. 128 5
      lib/DXIL/DxilMetadataHelper.cpp
  25. 74 4
      lib/DXIL/DxilModule.cpp
  26. 42 4
      lib/DXIL/DxilOperations.cpp
  27. 1 0
      lib/DXIL/DxilSemantic.cpp
  28. 8 3
      lib/DXIL/DxilShaderModel.cpp
  29. 12 1
      lib/DXIL/DxilSignature.cpp
  30. 2 2
      lib/DXIL/DxilSignatureElement.cpp
  31. 22 0
      lib/DXIL/DxilTypeSystem.cpp
  32. 71 25
      lib/DxilContainer/DxilContainerAssembler.cpp
  33. 1 1
      lib/DxilPIXPasses/DxilShaderAccessTracking.cpp
  34. 46 33
      lib/HLSL/ComputeViewIdState.cpp
  35. 36 17
      lib/HLSL/ComputeViewIdStateBuilder.cpp
  36. 3 3
      lib/HLSL/DxilContainerReflection.cpp
  37. 16 5
      lib/HLSL/DxilPreserveAllOutputs.cpp
  38. 429 41
      lib/HLSL/DxilValidation.cpp
  39. 2 0
      lib/HLSL/HLModule.cpp
  40. 32 0
      lib/HLSL/HLOperationLower.cpp
  41. 208 62
      lib/HLSL/HLSignatureLower.cpp
  42. 9 1
      lib/HLSL/HLSignatureLower.h
  43. 6 2
      lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
  44. 24 0
      tools/clang/include/clang/Basic/Attr.td
  45. 10 4
      tools/clang/include/clang/Basic/DiagnosticSemaKinds.td
  46. 4 0
      tools/clang/include/clang/Basic/TokenKinds.def
  47. 1 0
      tools/clang/include/clang/SPIRV/FeatureManager.h
  48. 7 0
      tools/clang/include/clang/SPIRV/SpirvBuilder.h
  49. 4 0
      tools/clang/include/clang/SPIRV/SpirvContext.h
  50. 222 14
      tools/clang/lib/CodeGen/CGHLSLMS.cpp
  51. 21 0
      tools/clang/lib/Parse/ParseDecl.cpp
  52. 10 0
      tools/clang/lib/Parse/ParseExpr.cpp
  53. 5 1
      tools/clang/lib/Parse/ParseStmt.cpp
  54. 4 0
      tools/clang/lib/Parse/ParseTentative.cpp
  55. 14 6
      tools/clang/lib/SPIRV/CapabilityVisitor.cpp
  56. 243 49
      tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
  57. 20 5
      tools/clang/lib/SPIRV/DeclResultIdMapper.h
  58. 3 0
      tools/clang/lib/SPIRV/FeatureManager.cpp
  59. 68 41
      tools/clang/lib/SPIRV/GlPerVertex.cpp
  60. 5 4
      tools/clang/lib/SPIRV/GlPerVertex.h
  61. 17 0
      tools/clang/lib/SPIRV/SpirvBuilder.cpp
  62. 528 43
      tools/clang/lib/SPIRV/SpirvEmitter.cpp
  63. 33 2
      tools/clang/lib/SPIRV/SpirvEmitter.h
  64. 32 0
      tools/clang/lib/Sema/SemaExpr.cpp
  65. 3 1
      tools/clang/lib/Sema/SemaExprCXX.cpp
  66. 141 8
      tools/clang/lib/Sema/SemaHLSL.cpp
  67. 158 142
      tools/clang/lib/Sema/gen_intrin_main_tables_15.h
  68. 1 1
      tools/clang/test/CodeGenHLSL/batch/expressions/intrinsics/misc/abs1.hlsl
  69. 22 0
      tools/clang/test/CodeGenHLSL/mesh-val/amplification.hlsl
  70. 22 0
      tools/clang/test/CodeGenHLSL/mesh-val/asOversizePayload.hlsl
  71. 78 0
      tools/clang/test/CodeGenHLSL/mesh-val/mesh.hlsl
  72. 21 0
      tools/clang/test/CodeGenHLSL/mesh-val/missingDispatchMesh.hlsl
  73. 62 0
      tools/clang/test/CodeGenHLSL/mesh-val/missingSetMeshOutputCounts.hlsl
  74. 63 0
      tools/clang/test/CodeGenHLSL/mesh-val/msOversizePayload.hlsl
  75. 23 0
      tools/clang/test/CodeGenHLSL/mesh-val/multipleDispatchMesh.hlsl
  76. 64 0
      tools/clang/test/CodeGenHLSL/mesh-val/multipleSetMeshOutputCounts.hlsl
  77. 24 0
      tools/clang/test/CodeGenHLSL/mesh-val/nonDominatingDispatchMesh.hlsl
  78. 63 0
      tools/clang/test/CodeGenHLSL/mesh-val/nonDominatingSetMeshOutputCounts.hlsl
  79. 74 0
      tools/clang/test/CodeGenHLSL/mesh-val/oversizeSM.hlsl
  80. 22 0
      tools/clang/test/CodeGenHLSL/mesh/amplification.hlsl
  81. 64 0
      tools/clang/test/CodeGenHLSL/mesh/illegalOutIndicesAssignment.hlsl
  82. 78 0
      tools/clang/test/CodeGenHLSL/mesh/mesh.hlsl
  83. 64 0
      tools/clang/test/CodeGenHLSL/mesh/multipleInPayload.hlsl
  84. 65 0
      tools/clang/test/CodeGenHLSL/mesh/multipleOutIndices.hlsl
  85. 65 0
      tools/clang/test/CodeGenHLSL/mesh/multipleOutPrimitives.hlsl
  86. 65 0
      tools/clang/test/CodeGenHLSL/mesh/multipleOutVertices.hlsl
  87. 63 0
      tools/clang/test/CodeGenHLSL/mesh/notArrayOutIndices.hlsl
  88. 63 0
      tools/clang/test/CodeGenHLSL/mesh/notArrayOutPrimitives.hlsl
  89. 63 0
      tools/clang/test/CodeGenHLSL/mesh/notArrayOutVertices.hlsl
  90. 63 0
      tools/clang/test/CodeGenHLSL/mesh/notUint2OutIndicesForLines.hlsl
  91. 63 0
      tools/clang/test/CodeGenHLSL/mesh/notUint3OutIndicesForTriangles.hlsl
  92. 63 0
      tools/clang/test/CodeGenHLSL/mesh/notUintOutIndices.hlsl
  93. 63 0
      tools/clang/test/CodeGenHLSL/mesh/notVectorOutIndices.hlsl
  94. 64 0
      tools/clang/test/CodeGenHLSL/mesh/readFromOutIndices.hlsl
  95. 63 0
      tools/clang/test/CodeGenHLSL/mesh/tooManyOutIndices.hlsl
  96. 63 0
      tools/clang/test/CodeGenHLSL/mesh/tooManyOutPrimitives.hlsl
  97. 63 0
      tools/clang/test/CodeGenHLSL/mesh/tooManyOutVertices.hlsl
  98. 62 0
      tools/clang/test/CodeGenSPIRV/meshshading.nv.amplification.hlsl
  99. 14 0
      tools/clang/test/CodeGenSPIRV/meshshading.nv.error1.amplification.hlsl
  100. 19 0
      tools/clang/test/CodeGenSPIRV/meshshading.nv.error1.mesh.hlsl

+ 95 - 65
docs/DXIL.rst

@@ -90,17 +90,19 @@ The shader model is specified as a named metadata in DXIL::
 
 The following values of <shaderModelName>_<major>_<minor> are supported:
 
-==================== ===================================== ===========
-Target               Legacy Models                         DXIL Models
-==================== ===================================== ===========
-Vertex shader (VS)   vs_4_0, vs_4_1, vs_5_0, vs_5_1        vs_6_0
-Hull shader (HS)     hs_5_0, hs_5_1                        hs_6_0
-Domain shader (DS)   ds_5_0, ds_5_1                        ds_6_0
-Geometry shader (GS) gs_4_0, gs_4_1, gs_5_0, gs_5_1        gs_6_0
-Pixel shader (PS)    ps_4_0, ps_4_1, ps_5_0, ps_5_1        ps_6_0
-Compute shader (CS)  cs_5_0 (cs_4_0 is mapped onto cs_5_0) cs_6_0
-Shader library       no support                            lib_6_1
-==================== ===================================== ===========
+====================      ===================================== ===========
+Target                    Legacy Models                         DXIL Models
+====================      ===================================== ===========
+Vertex shader (VS)        vs_4_0, vs_4_1, vs_5_0, vs_5_1        vs_6_0
+Hull shader (HS)          hs_5_0, hs_5_1                        hs_6_0
+Domain shader (DS)        ds_5_0, ds_5_1                        ds_6_0
+Geometry shader (GS)      gs_4_0, gs_4_1, gs_5_0, gs_5_1        gs_6_0
+Pixel shader (PS)         ps_4_0, ps_4_1, ps_5_0, ps_5_1        ps_6_0
+Compute shader (CS)       cs_5_0 (cs_4_0 is mapped onto cs_5_0) cs_6_0
+Shader library            no support                            lib_6_1
+Mesh shader (MS)          no support                            ms_6_5
+Amplification shader (AS) no support                            as_6_5
+========================= ===================================== ===========
 
 The DXIL verifier ensures that DXIL conforms to the specified shader model.
 
@@ -609,26 +611,30 @@ Signature Points are enumerated as follows in the SigPointKind
 .. <py::lines('SIGPOINT-RST')>hctdb_instrhelp.get_sigpoint_rst()</py>
 .. SIGPOINT-RST:BEGIN
 
-== ======== ======= ========== ============== ============= ============================================================================
-ID SigPoint Related ShaderKind PackingKind    SignatureKind Description
-== ======== ======= ========== ============== ============= ============================================================================
-0  VSIn     Invalid Vertex     InputAssembler Input         Ordinary Vertex Shader input from Input Assembler
-1  VSOut    Invalid Vertex     Vertex         Output        Ordinary Vertex Shader output that may feed Rasterizer
-2  PCIn     HSCPIn  Hull       None           Invalid       Patch Constant function non-patch inputs
-3  HSIn     HSCPIn  Hull       None           Invalid       Hull Shader function non-patch inputs
-4  HSCPIn   Invalid Hull       Vertex         Input         Hull Shader patch inputs - Control Points
-5  HSCPOut  Invalid Hull       Vertex         Output        Hull Shader function output - Control Point
-6  PCOut    Invalid Hull       PatchConstant  PatchConstant Patch Constant function output - Patch Constant data passed to Domain Shader
-7  DSIn     Invalid Domain     PatchConstant  PatchConstant Domain Shader regular input - Patch Constant data plus system values
-8  DSCPIn   Invalid Domain     Vertex         Input         Domain Shader patch input - Control Points
-9  DSOut    Invalid Domain     Vertex         Output        Domain Shader output - vertex data that may feed Rasterizer
-10 GSVIn    Invalid Geometry   Vertex         Input         Geometry Shader vertex input - qualified with primitive type
-11 GSIn     GSVIn   Geometry   None           Invalid       Geometry Shader non-vertex inputs (system values)
-12 GSOut    Invalid Geometry   Vertex         Output        Geometry Shader output - vertex data that may feed Rasterizer
-13 PSIn     Invalid Pixel      Vertex         Input         Pixel Shader input
-14 PSOut    Invalid Pixel      Target         Output        Pixel Shader output
-15 CSIn     Invalid Compute    None           Invalid       Compute Shader input
-== ======== ======= ========== ============== ============= ============================================================================
+== ======== ======= ============= ============== ================ ============================================================================
+ID SigPoint Related ShaderKind    PackingKind    SignatureKind    Description
+== ======== ======= ============= ============== ================ ============================================================================
+0  VSIn     Invalid Vertex        InputAssembler Input            Ordinary Vertex Shader input from Input Assembler
+1  VSOut    Invalid Vertex        Vertex         Output           Ordinary Vertex Shader output that may feed Rasterizer
+2  PCIn     HSCPIn  Hull          None           Invalid          Patch Constant function non-patch inputs
+3  HSIn     HSCPIn  Hull          None           Invalid          Hull Shader function non-patch inputs
+4  HSCPIn   Invalid Hull          Vertex         Input            Hull Shader patch inputs - Control Points
+5  HSCPOut  Invalid Hull          Vertex         Output           Hull Shader function output - Control Point
+6  PCOut    Invalid Hull          PatchConstant  PatchConstOrPrim Patch Constant function output - Patch Constant data passed to Domain Shader
+7  DSIn     Invalid Domain        PatchConstant  PatchConstOrPrim Domain Shader regular input - Patch Constant data plus system values
+8  DSCPIn   Invalid Domain        Vertex         Input            Domain Shader patch input - Control Points
+9  DSOut    Invalid Domain        Vertex         Output           Domain Shader output - vertex data that may feed Rasterizer
+10 GSVIn    Invalid Geometry      Vertex         Input            Geometry Shader vertex input - qualified with primitive type
+11 GSIn     GSVIn   Geometry      None           Invalid          Geometry Shader non-vertex inputs (system values)
+12 GSOut    Invalid Geometry      Vertex         Output           Geometry Shader output - vertex data that may feed Rasterizer
+13 PSIn     Invalid Pixel         Vertex         Input            Pixel Shader input
+14 PSOut    Invalid Pixel         Target         Output           Pixel Shader output
+15 CSIn     Invalid Compute       None           Invalid          Compute Shader input
+16 MSIn     Invalid Mesh          None           Invalid          Mesh Shader input
+17 MSOut    Invalid Mesh          Vertex         Output           Mesh Shader vertices output
+18 MSPOut   Invalid Mesh          Vertex         PatchConstOrPrim Mesh Shader primitives output
+19 ASIn     Invalid Amplification None           Invalid          Amplification Shader input
+== ======== ======= ============= ============== ================ ============================================================================
 
 .. SIGPOINT-RST:END
 
@@ -663,40 +669,41 @@ Semantic Interpretations for each SemanticKind at each SigPointKind are as follo
 .. <py::lines('SEMINT-TABLE-RST')>hctdb_instrhelp.get_sem_interpretation_table_rst()</py>
 .. SEMINT-TABLE-RST:BEGIN
 
-====================== ============ ======== ============ ============ ======== ======== ========== ============ ======== ======== ======== ============ ======== ============= ============= ========
-Semantic               VSIn         VSOut    PCIn         HSIn         HSCPIn   HSCPOut  PCOut      DSIn         DSCPIn   DSOut    GSVIn    GSIn         GSOut    PSIn          PSOut         CSIn
-====================== ============ ======== ============ ============ ======== ======== ========== ============ ======== ======== ======== ============ ======== ============= ============= ========
-Arbitrary              Arb          Arb      NA           NA           Arb      Arb      Arb        Arb          Arb      Arb      Arb      NA           Arb      Arb           NA            NA
-VertexID               SV           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NA
-InstanceID             SV           Arb      NA           NA           Arb      Arb      NA         NA           Arb      Arb      Arb      NA           Arb      Arb           NA            NA
-Position               Arb          SV       NA           NA           SV       SV       Arb        Arb          SV       SV       SV       NA           SV       SV            NA            NA
-RenderTargetArrayIndex Arb          SV       NA           NA           SV       SV       Arb        Arb          SV       SV       SV       NA           SV       SV            NA            NA
-ViewPortArrayIndex     Arb          SV       NA           NA           SV       SV       Arb        Arb          SV       SV       SV       NA           SV       SV            NA            NA
-ClipDistance           Arb          ClipCull NA           NA           ClipCull ClipCull Arb        Arb          ClipCull ClipCull ClipCull NA           ClipCull ClipCull      NA            NA
-CullDistance           Arb          ClipCull NA           NA           ClipCull ClipCull Arb        Arb          ClipCull ClipCull ClipCull NA           ClipCull ClipCull      NA            NA
-OutputControlPointID   NA           NA       NA           NotInSig     NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NA
-DomainLocation         NA           NA       NA           NA           NA       NA       NA         NotInSig     NA       NA       NA       NA           NA       NA            NA            NA
-PrimitiveID            NA           NA       NotInSig     NotInSig     NA       NA       NA         NotInSig     NA       NA       NA       Shadow       SGV      SGV           NA            NA
-GSInstanceID           NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NotInSig     NA       NA            NA            NA
-SampleIndex            NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       Shadow _41    NA            NA
-IsFrontFace            NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           SGV      SGV           NA            NA
-Coverage               NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NotInSig _50  NotPacked _41 NA
-InnerCoverage          NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NotInSig _50  NA            NA
-Target                 NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            Target        NA
-Depth                  NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked     NA
-DepthLessEqual         NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked _50 NA
-DepthGreaterEqual      NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked _50 NA
-StencilRef             NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked _50 NA
-DispatchThreadID       NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig
-GroupID                NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig
-GroupIndex             NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig
-GroupThreadID          NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig
-TessFactor             NA           NA       NA           NA           NA       NA       TessFactor TessFactor   NA       NA       NA       NA           NA       NA            NA            NA
-InsideTessFactor       NA           NA       NA           NA           NA       NA       TessFactor TessFactor   NA       NA       NA       NA           NA       NA            NA            NA
-ViewID                 NotInSig _61 NA       NotInSig _61 NotInSig _61 NA       NA       NA         NotInSig _61 NA       NA       NA       NotInSig _61 NA       NotInSig _61  NA            NA
-Barycentrics           NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NotPacked _61 NA            NA
-ShadingRate            NA           SV _64   NA           NA           SV _64   SV _64   NA         NA           SV _64   SV _64   SV _64   NA           SV _64   SV _64        NA            NA
-====================== ============ ======== ============ ============ ======== ======== ========== ============ ======== ======== ======== ============ ======== ============= ============= ========
+====================== ============ ======== ============ ============ ======== ======== ========== ============ ======== ======== ======== ============ ======== ============= ============= ======== ============ ============ ======= ============
+Semantic               VSIn         VSOut    PCIn         HSIn         HSCPIn   HSCPOut  PCOut      DSIn         DSCPIn   DSOut    GSVIn    GSIn         GSOut    PSIn          PSOut         CSIn     MSIn         MSOut        MSPOut  ASIn
+====================== ============ ======== ============ ============ ======== ======== ========== ============ ======== ======== ======== ============ ======== ============= ============= ======== ============ ============ ======= ============
+Arbitrary              Arb          Arb      NA           NA           Arb      Arb      Arb        Arb          Arb      Arb      Arb      NA           Arb      Arb           NA            NA       NA           Arb _65      Arb _65 NA
+VertexID               SV           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NA       NA           NA           NA      NA
+InstanceID             SV           Arb      NA           NA           Arb      Arb      NA         NA           Arb      Arb      Arb      NA           Arb      Arb           NA            NA       NA           NA           NA      NA
+Position               Arb          SV       NA           NA           SV       SV       Arb        Arb          SV       SV       SV       NA           SV       SV            NA            NA       NA           SV _65       NA      NA
+RenderTargetArrayIndex Arb          SV       NA           NA           SV       SV       Arb        Arb          SV       SV       SV       NA           SV       SV            NA            NA       NA           NA           SV _65  NA
+ViewPortArrayIndex     Arb          SV       NA           NA           SV       SV       Arb        Arb          SV       SV       SV       NA           SV       SV            NA            NA       NA           NA           SV _65  NA
+ClipDistance           Arb          ClipCull NA           NA           ClipCull ClipCull Arb        Arb          ClipCull ClipCull ClipCull NA           ClipCull ClipCull      NA            NA       NA           ClipCull _65 NA      NA
+CullDistance           Arb          ClipCull NA           NA           ClipCull ClipCull Arb        Arb          ClipCull ClipCull ClipCull NA           ClipCull ClipCull      NA            NA       NA           ClipCull _65 NA      NA
+OutputControlPointID   NA           NA       NA           NotInSig     NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NA       NA           NA           NA      NA
+DomainLocation         NA           NA       NA           NA           NA       NA       NA         NotInSig     NA       NA       NA       NA           NA       NA            NA            NA       NA           NA           NA      NA
+PrimitiveID            NA           NA       NotInSig     NotInSig     NA       NA       NA         NotInSig     NA       NA       NA       Shadow       SGV      SGV           NA            NA       NA           NA           SV _65  NA
+GSInstanceID           NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NotInSig     NA       NA            NA            NA       NA           NA           NA      NA
+SampleIndex            NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       Shadow _41    NA            NA       NA           NA           NA      NA
+IsFrontFace            NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           SGV      SGV           NA            NA       NA           NA           NA      NA
+Coverage               NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NotInSig _50  NotPacked _41 NA       NA           NA           NA      NA
+InnerCoverage          NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NotInSig _50  NA            NA       NA           NA           NA      NA
+Target                 NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            Target        NA       NA           NA           NA      NA
+Depth                  NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked     NA       NA           NA           NA      NA
+DepthLessEqual         NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked _50 NA       NA           NA           NA      NA
+DepthGreaterEqual      NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked _50 NA       NA           NA           NA      NA
+StencilRef             NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NotPacked _50 NA       NA           NA           NA      NA
+DispatchThreadID       NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig NotInSig _65 NA           NA      NotInSig _65
+GroupID                NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig NotInSig _65 NA           NA      NotInSig _65
+GroupIndex             NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig NotInSig _65 NA           NA      NotInSig _65
+GroupThreadID          NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NotInSig NotInSig _65 NA           NA      NotInSig _65
+TessFactor             NA           NA       NA           NA           NA       NA       TessFactor TessFactor   NA       NA       NA       NA           NA       NA            NA            NA       NA           NA           NA      NA
+InsideTessFactor       NA           NA       NA           NA           NA       NA       TessFactor TessFactor   NA       NA       NA       NA           NA       NA            NA            NA       NA           NA           NA      NA
+ViewID                 NotInSig _61 NA       NotInSig _61 NotInSig _61 NA       NA       NA         NotInSig _61 NA       NA       NA       NotInSig _61 NA       NotInSig _61  NA            NA       NotInSig _65 NA           NA      NA
+Barycentrics           NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NotPacked _61 NA            NA       NA           NA           NA      NA
+ShadingRate            NA           SV _64   NA           NA           SV _64   SV _64   NA         NA           SV _64   SV _64   SV _64   NA           SV _64   SV _64        NA            NA       NA           NA           NA      NA
+CullPrimitive          NA           NA       NA           NA           NA       NA       NA         NA           NA       NA       NA       NA           NA       NA            NA            NA       NA           NA           SV _65  NA
+====================== ============ ======== ============ ============ ======== ======== ========== ============ ======== ======== ======== ============ ======== ============= ============= ======== ============ ============ ======= ============
 
 .. SEMINT-TABLE-RST:END
 
@@ -2256,6 +2263,12 @@ ID  Name                          Description
 165 WaveMatch                     returns the bitmask of active lanes that have the same value
 166 WaveMultiPrefixOp             returns the result of the operation on groups of lanes identified by a bitmask
 167 WaveMultiPrefixBitCount       returns the count of bits set to 1 on groups of lanes identified by a bitmask
+168 SetMeshOutputCounts           Mesh shader intrinsic SetMeshOutputCounts
+169 EmitIndices                   emit a primitive's vertex indices in a mesh shader
+170 GetMeshPayload                get the mesh payload which is from amplification shader
+171 StoreVertexOutput             stores the value to mesh shader vertex output
+172 StorePrimitiveOutput          stores the value to mesh shader primitive output
+173 DispatchMesh                  Amplification shader intrinsic DispatchMesh
 === ============================= =======================================================================================================================================================================================================================
 
 
@@ -2957,13 +2970,19 @@ INSTR.MINPRECISIONNOTPRECISE             Instructions marked precise may not ref
 INSTR.MINPRECISONBITCAST                 Bitcast on minprecison types is not allowed
 INSTR.MIPLEVELFORGETDIMENSION            Use mip level on buffer when GetDimensions
 INSTR.MIPONUAVLOAD                       uav load don't support mipLevel/sampleIndex
+INSTR.MISSINGSETMESHOUTPUTCOUNTS         Missing SetMeshOutputCounts call.
+INSTR.MULTIPLEGETMESHPAYLOAD             GetMeshPayload cannot be called multiple times.
+INSTR.MULTIPLESETMESHOUTPUTCOUNTS        SetMeshOUtputCounts cannot be called multiple times.
 INSTR.NOGENERICPTRADDRSPACECAST          Address space cast between pointer types must have one part to be generic address space
 INSTR.NOIDIVBYZERO                       No signed integer division by zero
 INSTR.NOINDEFINITEACOS                   No indefinite arccosine
 INSTR.NOINDEFINITEASIN                   No indefinite arcsine
 INSTR.NOINDEFINITEDSXY                   No indefinite derivative calculation
 INSTR.NOINDEFINITELOG                    No indefinite logarithm
+INSTR.NONDOMINATINGDISPATCHMESH          Non-Dominating DispatchMesh call.
+INSTR.NONDOMINATINGSETMESHOUTPUTCOUNTS   Non-Dominating SetMeshOutputCounts call.
 INSTR.NOREADINGUNINITIALIZED             Instructions should not read uninitialized value
+INSTR.NOTONCEDISPATCHMESH                DispatchMesh must be called exactly once in an Amplification shader.
 INSTR.NOUDIVBYZERO                       No unsigned integer division by zero
 INSTR.OFFSETONUAVLOAD                    uav load don't support offset
 INSTR.OLOAD                              DXIL intrinsic overload must be valid
@@ -3052,12 +3071,14 @@ META.VALIDSAMPLERMODE                    Invalid sampler mode on sampler
 META.VALUERANGE                          Metadata value must be within range
 META.WELLFORMED                          TODO - Metadata must be well-formed in operand count and types
 SM.64BITRAWBUFFERLOADSTORE               i64/f64 rawBufferLoad/Store overloads are allowed after SM 6.3
+SM.AMPLIFICATIONSHADERPAYLOADSIZE        For shader '%0', payload size is greater than %1
 SM.APPENDANDCONSUMEONSAMEUAV             BufferUpdateCounter inc and dec on a given UAV (%d) cannot both be in the same shader for shader model less than 5.1.
 SM.CBUFFERARRAYOFFSETALIGNMENT           CBuffer array offset must be aligned to 16-bytes
 SM.CBUFFERELEMENTOVERFLOW                CBuffer elements must not overflow
 SM.CBUFFEROFFSETOVERLAP                  CBuffer offsets must not overlap
 SM.CBUFFERTEMPLATETYPEMUSTBESTRUCT       D3D12 constant/texture buffer template element can only be a struct
 SM.COMPLETEPOSITION                      Not all elements of SV_Position were written
+SM.CONSTANTINTERPMODE                    Interpolation mode must be constant for MS primitive output.
 SM.COUNTERONLYONSTRUCTBUF                BufferUpdateCounter valid only on structured buffers
 SM.CSNOSIGNATURES                        Compute shaders must not have shader signatures.
 SM.DOMAINLOCATIONIDXOOB                  DomainLocation component index out of bounds for the domain.
@@ -3075,8 +3096,17 @@ SM.INVALIDRESOURCECOMPTYPE               Invalid resource return type
 SM.INVALIDRESOURCEKIND                   Invalid resources kind
 SM.INVALIDTEXTUREKINDONUAV               Texture2DMS[Array] or TextureCube[Array] resources are not supported with UAVs
 SM.ISOLINEOUTPUTPRIMITIVEMISMATCH        Hull Shader declared with IsoLine Domain must specify output primitive point or line. Triangle_cw or triangle_ccw output are not compatible with the IsoLine Domain.
+SM.MAXMSSMSIZE                           Total Thread Group Shared Memory storage is %0, exceeded %1
 SM.MAXTGSMSIZE                           Total Thread Group Shared Memory storage is %0, exceeded %1
 SM.MAXTHEADGROUP                         Declared Thread Group Count %0 (X*Y*Z) is beyond the valid maximum of %1
+SM.MESHPSIGROWCOUNT                      For shader '%0', primitive output signatures are taking up more than %1 rows
+SM.MESHSHADERINOUTSIZE                   For shader '%0', input plus output size is greater than %1
+SM.MESHSHADERMAXPRIMITIVECOUNT           MS max primitive output count must be [0..%0].  %1 specified
+SM.MESHSHADERMAXVERTEXCOUNT              MS max vertex output count must be [0..%0].  %1 specified
+SM.MESHSHADEROUTPUTSIZE                  For shader '%0', vertex plus primitive output size is greater than %1
+SM.MESHSHADERPAYLOADSIZE                 For shader '%0', payload size is greater than %1
+SM.MESHTOTALSIGROWCOUNT                  For shader '%0', vertex and primitive output signatures are taking up more than %1 rows
+SM.MESHVSIGROWCOUNT                      For shader '%0', vertex output signatures are taking up more than %1 rows
 SM.MULTISTREAMMUSTBEPOINT                When multiple GS output streams are used they must be pointlists
 SM.NAME                                  Target shader model name must be known
 SM.NOINTERPMODE                          Interpolation mode must be undefined for VS input/PS output/patch constant.

+ 124 - 18
docs/SPIR-V.rst

@@ -264,6 +264,7 @@ Right now the following ``<builtin>`` are supported:
   Need ``SPV_KHR_shader_draw_parameters`` extension.
 * ``DeviceIndex``: The GLSL equivalent is ``gl_DeviceIndex``.
   Need ``SPV_KHR_device_group`` extension.
+* ``ViewportMaskNV``: The GLSL equivalent is ``gl_ViewportMask``.
 
 Please see Vulkan spec. `14.6. Built-In Variables <https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/vkspec.html#interfaces-builtin-variables>`_
 for detailed explanation of these builtins.
@@ -282,6 +283,7 @@ Supported extensions
 * SPV_EXT_shader_stencil_support
 * SPV_AMD_shader_explicit_vertex_parameter
 * SPV_GOOGLE_hlsl_functionality1
+* SPV_NV_mesh_shader
 
 Vulkan specific attributes
 --------------------------
@@ -1272,14 +1274,16 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 |                           | HSCPOut     | ``Position``             | N/A                   | ``Shader``                  |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | DSCPIn      | ``Position``             | N/A                   | ``Shader``                  |
-| SV_Position               +-------------+--------------------------+-----------------------+-----------------------------+
-|                           | DSOut       | ``Position``             | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_Position               | DSOut       | ``Position``             | N/A                   | ``Shader``                  |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSVIn       | ``Position``             | N/A                   | ``Shader``                  |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``Position``             | N/A                   | ``Shader``                  |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``FragCoord``            | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSOut       | ``Position``             | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 |                           | VSOut       | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
@@ -1288,14 +1292,16 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 |                           | HSCPOut     | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | DSCPIn      | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
-| SV_ClipDistance           +-------------+--------------------------+-----------------------+-----------------------------+
-|                           | DSOut       | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_ClipDistance           | DSOut       | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSVIn       | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSOut       | ``ClipDistance``         | N/A                   | ``ClipDistance``            |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 |                           | VSOut       | ``CullDistance``         | N/A                   | ``CullDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
@@ -1304,14 +1310,16 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 |                           | HSCPOut     | ``CullDistance``         | N/A                   | ``CullDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | DSCPIn      | ``CullDistance``         | N/A                   | ``CullDistance``            |
-| SV_CullDistance           +-------------+--------------------------+-----------------------+-----------------------------+
-|                           | DSOut       | ``CullDistance``         | N/A                   | ``CullDistance``            |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_CullDistance           | DSOut       | ``CullDistance``         | N/A                   | ``CullDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSVIn       | ``CullDistance``         | N/A                   | ``CullDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``CullDistance``         | N/A                   | ``CullDistance``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``CullDistance``         | N/A                   | ``CullDistance``            |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSOut       | ``CullDistance``         | N/A                   | ``CullDistance``            |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 | SV_VertexID               | VSIn        | ``VertexIndex``          | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
@@ -1325,13 +1333,29 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 | SV_IsFrontFace            | PSIn        | ``FrontFacing``          | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
-| SV_DispatchThreadID       | CSIn        | ``GlobalInvocationId``   | N/A                   | ``Shader``                  |
+|                           | CSIn        | ``GlobalInvocationId``   | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_DispatchThreadID       | MSIn        | ``GlobalInvocationId``   | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | ASIn        | ``GlobalInvocationId``   | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
-| SV_GroupID                | CSIn        | ``WorkgroupId``          | N/A                   | ``Shader``                  |
+|                           | CSIn        | ``WorkgroupId``          | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_GroupID                | MSIn        | ``WorkgroupId``          | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | ASIn        | ``WorkgroupId``          | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
-| SV_GroupThreadID          | CSIn        | ``LocalInvocationId``    | N/A                   | ``Shader``                  |
+|                           | CSIn        | ``LocalInvocationId``    | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_GroupThreadID          | MSIn        | ``LocalInvocationId``    | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | ASIn        | ``LocalInvocationId``    | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
-| SV_GroupIndex             | CSIn        | ``LocalInvocationIndex`` | N/A                   | ``Shader``                  |
+|                           | CSIn        | ``LocalInvocationIndex`` | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_GroupIndex             | MSIn        | ``LocalInvocationIndex`` | N/A                   | ``Shader``                  |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | ASIn        | ``LocalInvocationIndex`` | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 | SV_OutputControlPointID   | HSIn        | ``InvocationId``         | N/A                   | ``Tessellation``            |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
@@ -1344,12 +1368,14 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 |                           | PCIn        | ``PrimitiveId``          | N/A                   | ``Tessellation``            |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | DsIn        | ``PrimitiveId``          | N/A                   | ``Tessellation``            |
-| SV_PrimitiveID            +-------------+--------------------------+-----------------------+-----------------------------+
-|                           | GSIn        | ``PrimitiveId``          | N/A                   | ``Geometry``                |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_PrimitiveID            | GSIn        | ``PrimitiveId``          | N/A                   | ``Geometry``                |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``PrimitiveId``          | N/A                   | ``Geometry``                |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``PrimitiveId``          | N/A                   | ``Geometry``                |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSOut       | ``PrimitiveId``          | N/A                   | ``MeshShadingNV``           |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PCOut       | ``TessLevelOuter``       | N/A                   | ``Tessellation``            |
 | SV_TessFactor             +-------------+--------------------------+-----------------------+-----------------------------+
@@ -1366,12 +1392,16 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 | SV_Barycentrics           | PSIn        | ``BaryCoord*AMD``        | N/A                   | ``Shader``                  |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``Layer``                | N/A                   | ``Geometry``                |
-| SV_RenderTargetArrayIndex +-------------+--------------------------+-----------------------+-----------------------------+
-|                           | PSIn        | ``Layer``                | N/A                   | ``Geometry``                |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_RenderTargetArrayIndex | PSIn        | ``Layer``                | N/A                   | ``Geometry``                |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSOut       | ``Layer``                | N/A                   | ``MeshShadingNV``           |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``ViewportIndex``        | N/A                   | ``MultiViewport``           |
-| SV_ViewportArrayIndex     +-------------+--------------------------+-----------------------+-----------------------------+
-|                           | PSIn        | ``ViewportIndex``        | N/A                   | ``MultiViewport``           |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+| SV_ViewportArrayIndex     | PSIn        | ``ViewportIndex``        | N/A                   | ``MultiViewport``           |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSOut       | ``ViewportIndex``        | N/A                   | ``MeshShadingNV``           |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``SampleMask``           | N/A                   | ``Shader``                  |
 | SV_Coverage               +-------------+--------------------------+-----------------------+-----------------------------+
@@ -1383,11 +1413,13 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | HSIn        | ``ViewIndex``            | N/A                   | ``MultiView``               |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
-| SV_ViewID                 | DSIn        | ``ViewIndex``            | N/A                   | ``MultiView``               |
-|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | DSIn        | ``ViewIndex``            | N/A                   | ``MultiView``               |
+| SV_ViewID                 +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | GSIn        | ``ViewIndex``            | N/A                   | ``MultiView``               |
 |                           +-------------+--------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``ViewIndex``            | N/A                   | ``MultiView``               |
+|                           +-------------+--------------------------+-----------------------+-----------------------------+
+|                           | MSIn        | ``ViewIndex``            | N/A                   | ``MultiView``               |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
 | SV_ShadingRate            | PSIn        | ``FragSizeEXT``          | N/A                   | ``FragmentDensityEXT``      |
 +---------------------------+-------------+--------------------------+-----------------------+-----------------------------+
@@ -3069,6 +3101,80 @@ Callable Stage
     a.color = float4(0.0f,1.0f,0.0f,0.0f);
   }
 
+Mesh and Amplification Shaders
+------------------------------
+
+DirectX adds 2 new shader stages for using MeshShading pipeline namely Mesh and Amplification.
+Amplification shaders corresponds to Task Shaders in Vulkan.
+
+| Refer to following HLSL and SPIR-V specs for details:
+| https://docs.microsoft.com/<TBD>
+| https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/NV/SPV_NV_mesh_shader.asciidoc
+
+This section describes how Mesh and Amplification shaders are translated to SPIR-V for Vulkan.
+
+Entry Point Attributes
+~~~~~~~~~~~~~~~~~~~~~~
+The following HLSL attributes are attached to the main entry point of Mesh and/or Amplification
+shaders and are translated to SPIR-V execution modes according to the table below:
+
+.. table:: Mapping from HLSL attribute to SPIR-V execution mode
+
++--------------------+----------------+-------------------------+
+|  HLSL Attribute    |   Value        | SPIR-V Execution Mode   |
++====================+================+=========================+
+|                    | ``point``      | ``OutputPoints``        |
+|                    +----------------+-------------------------+
+| ``outputtopology`` | ``line``       | ``OutputLinesNV``       |
+|   (Mesh shader)    +----------------+-------------------------+
+|                    | ``triangle``   | ``OutputTrianglesNV``   |
++--------------------+----------------+-------------------------+
+| ``numthreads``     | ``X, Y, Z``    | ``LocalSize X, Y, Z``   |
+|                    | (X*Y*Z <= 128) |                         |
++--------------------+----------------+-------------------------+
+
+Intrinsics
+~~~~~~~~~~
+The following HLSL intrinsics are used in Mesh or Amplification shaders
+and are translated to SPIR-V intrinsics according to the table below:
+
+.. table:: Mapping from HLSL intrinsics to SPIR-V intrinsics
+
++-------------------------+--------------------+-----------------------------------------+
+|  HLSL Intrinsic         |  Parameters        | SPIR-V Intrinsic                        |
++=========================+====================+=========================================+
+| ``SetMeshOutputCounts`` | ``numVertices``    | ``PrimitiveCountNV numPrimitives``      |
+|     (Mesh shader)       | ``numPrimitives``  |                                         |
++-------------------------+--------------------+-----------------------------------------+
+|                         | ``ThreadX``        |                                         |
+| ``DispatchMesh``        | ``ThreadY``        |  ``OpControlBarrier``                   |
+| (Amplification shader)  | ``ThreadZ``        | ``TaskCountNV ThreadX*ThreadY*ThreadZ`` |
+|                         | ``MeshPayload``    |                                         |
++-------------------------+--------------------+-----------------------------------------+
+
+| *For DispatchMesh intrinsic, we also emit MeshPayload as output block with PerTaskNV decoration
+
+Mesh Interface Variables
+~~~~~~~~~~~~~~~~~~~~~~~~
+Interface variables are defined for Mesh shaders using HLSL modifiers.
+Following table gives high level overview of the mapping:
+
+.. table:: Mapping from HLSL modifiers to SPIR-V definitions
+
++-----------------+-------------------------------------------------------------------------+
+|  HLSL modifier  | SPIR-V definition                                                       |
++=================+=========================================================================+
+| ``indices``     | Maps to SPIR-V intrinsic ``PrimitiveIndicesNV``                         |  
+|                 | Defines SPIR-V Execution Mode ``OutputPrimitivesNV <array-size>``       |
++-----------------+-------------------------------------------------------------------------+
+| ``vertices``    | Maps to per-vertex out attributes                                       |
+|                 | Defines existing SPIR-V Execution Mode ``OutputVertices <array-size>``  |
++-----------------+-------------------------------------------------------------------------+
+| ``primitives``  | Maps to per-primitive out attributes with ``PerPrimitiveNV`` decoration |
++-----------------+-------------------------------------------------------------------------+
+| ``payload``     | Maps to per-task in attributes with ``PerTaskNV`` decoration            |
++-----------------+-------------------------------------------------------------------------+
+
 
 Raytracing in Vulkan and SPIRV
 ==============================

+ 1 - 1
external/SPIRV-Headers

@@ -1 +1 @@
-Subproject commit de99d4d834aeb51dd9f099baa285bd44fd04bb3d
+Subproject commit 29c11140baaf9f7fdaa39a583672c556bf1795a1

+ 1 - 1
external/SPIRV-Tools

@@ -1 +1 @@
-Subproject commit df86bb44fe476515f9a298bacd8e1d4a3522e989
+Subproject commit 5081512502df32b38a38712adb0d8c1b23bb1c2e

+ 73 - 12
include/dxc/DXIL/DxilConstants.h

@@ -68,17 +68,33 @@ namespace DXIL {
   const float kHSMaxTessFactorUpperBound = 64.0f;
   const unsigned kHSDefaultInputControlPointCount = 1;
   const unsigned kMaxCSThreadsPerGroup = 1024;
-  const unsigned kMaxCSThreadGroupX	= 1024;
-  const unsigned kMaxCSThreadGroupY	= 1024;
+  const unsigned kMaxCSThreadGroupX = 1024;
+  const unsigned kMaxCSThreadGroupY = 1024;
   const unsigned kMaxCSThreadGroupZ = 64;
   const unsigned kMinCSThreadGroupX = 1;
   const unsigned kMinCSThreadGroupY = 1;
   const unsigned kMinCSThreadGroupZ = 1;
   const unsigned kMaxCS4XThreadsPerGroup = 768;
-  const unsigned kMaxCS4XThreadGroupX	= 768;
-  const unsigned kMaxCS4XThreadGroupY	= 768;
+  const unsigned kMaxCS4XThreadGroupX = 768;
+  const unsigned kMaxCS4XThreadGroupY = 768;
   const unsigned kMaxTGSMSize = 8192*4;
   const unsigned kMaxGSOutputTotalScalars = 1024;
+  const unsigned kMaxMSASThreadsPerGroup = 128;
+  const unsigned kMaxMSASThreadGroupX = 128;
+  const unsigned kMaxMSASThreadGroupY = 128;
+  const unsigned kMaxMSASThreadGroupZ = 128;
+  const unsigned kMinMSASThreadGroupX = 1;
+  const unsigned kMinMSASThreadGroupY = 1;
+  const unsigned kMinMSASThreadGroupZ = 1;
+  const unsigned kMaxMSASPayloadSize = 16384;
+  const unsigned kMaxMSOutputPrimitiveCount = 256;
+  const unsigned kMaxMSOutputVertexCount = 256;
+  const unsigned kMaxMSOutputTotalScalars = 32768;
+  const unsigned kMaxMSInputOutputTotalScalars = 41984;
+  const unsigned kMaxMSVSigRows = 32;
+  const unsigned kMaxMSPSigRows = 32;
+  const unsigned kMaxMSTotalSigRows = 32;
+  const unsigned kMaxMSSMSize = 1024 * 28;
 
   const float kMaxMipLodBias = 15.99f;
   const float kMinMipLodBias = -16.0f;
@@ -116,7 +132,7 @@ namespace DXIL {
     Invalid = 0,
     Input,
     Output,
-    PatchConstant,
+    PatchConstOrPrim,
   };
 
   // Must match D3D11_SHADER_VERSION_TYPE
@@ -134,6 +150,8 @@ namespace DXIL {
     ClosestHit,
     Miss,
     Callable,
+    Mesh,
+    Amplification,
     Invalid,
   };
 
@@ -171,6 +189,7 @@ namespace DXIL {
     ViewID,
     Barycentrics,
     ShadingRate,
+    CullPrimitive,
     Invalid,
   };
   // SemanticKind-ENUM:END
@@ -195,6 +214,10 @@ namespace DXIL {
     PSIn, // Pixel Shader input
     PSOut, // Pixel Shader output
     CSIn, // Compute Shader input
+    MSIn, // Mesh Shader input
+    MSOut, // Mesh Shader vertices output
+    MSPOut, // Mesh Shader primitives output
+    ASIn, // Amplification Shader input
     Invalid,
   };
   // SigPointKind-ENUM:END
@@ -299,6 +322,9 @@ namespace DXIL {
   // OPCODE-ENUM:BEGIN
   // Enumeration for operations specified by DXIL
   enum class OpCode : unsigned {
+    // Amplification shader instructions
+    DispatchMesh = 173, // Amplification shader intrinsic DispatchMesh
+  
     // AnyHit Terminals
     AcceptHitAndEndSearch = 156, // Used in an any hit shader to abort the ray query and the intersection shader (if any). The current hit is committed and execution passes to the closest hit shader with the closest hit recorded so far
     IgnoreHit = 155, // Used in an any hit shader to reject an intersection and terminate the shader
@@ -334,7 +360,7 @@ namespace DXIL {
     BitcastI32toF32 = 126, // bitcast between different sizes
     BitcastI64toF64 = 128, // bitcast between different sizes
   
-    // Compute shader
+    // Compute/Mesh/Amplification shader
     FlattenedThreadIdInGroup = 96, // provides a flattened index for a given thread within a given group (SV_GroupIndex)
     GroupId = 94, // reads the group ID (SV_GroupID)
     ThreadId = 93, // reads the thread ID
@@ -392,6 +418,13 @@ namespace DXIL {
     // Library create handle from resource struct (like HL intrinsic)
     CreateHandleForLib = 160, // create resource handle from resource struct for library
   
+    // Mesh shader instructions
+    EmitIndices = 169, // emit a primitive's vertex indices in a mesh shader
+    GetMeshPayload = 170, // get the mesh payload which is from amplification shader
+    SetMeshOutputCounts = 168, // Mesh shader intrinsic SetMeshOutputCounts
+    StorePrimitiveOutput = 172, // stores the value to mesh shader primitive output
+    StoreVertexOutput = 171, // stores the value to mesh shader vertex output
+  
     // Other
     CycleCounterLegacy = 109, // CycleCounterLegacy
   
@@ -562,9 +595,9 @@ namespace DXIL {
     NumOpCodes_Dxil_1_2 = 141,
     NumOpCodes_Dxil_1_3 = 162,
     NumOpCodes_Dxil_1_4 = 165,
-    NumOpCodes_Dxil_1_5 = 168,
+    NumOpCodes_Dxil_1_5 = 174,
   
-    NumOpCodes = 168 // exclusive last value of enumeration
+    NumOpCodes = 174 // exclusive last value of enumeration
   };
   // OPCODE-ENUM:END
 
@@ -572,6 +605,9 @@ namespace DXIL {
   // OPCODECLASS-ENUM:BEGIN
   // Groups for DXIL operations with equivalent function templates
   enum class OpCodeClass : unsigned {
+    // Amplification shader instructions
+    DispatchMesh,
+  
     // AnyHit Terminals
     AcceptHitAndEndSearch,
     IgnoreHit,
@@ -593,7 +629,7 @@ namespace DXIL {
     BitcastI32toF32,
     BitcastI64toF64,
   
-    // Compute shader
+    // Compute/Mesh/Amplification shader
     FlattenedThreadIdInGroup,
     GroupId,
     ThreadId,
@@ -653,6 +689,13 @@ namespace DXIL {
     // Library create handle from resource struct (like HL intrinsic)
     CreateHandleForLib,
   
+    // Mesh shader instructions
+    EmitIndices,
+    GetMeshPayload,
+    SetMeshOutputCounts,
+    StorePrimitiveOutput,
+    StoreVertexOutput,
+  
     // Other
     CycleCounterLegacy,
   
@@ -778,9 +821,9 @@ namespace DXIL {
     NumOpClasses_Dxil_1_2 = 97,
     NumOpClasses_Dxil_1_3 = 118,
     NumOpClasses_Dxil_1_4 = 120,
-    NumOpClasses_Dxil_1_5 = 123,
+    NumOpClasses_Dxil_1_5 = 129,
   
-    NumOpClasses = 123 // exclusive last value of enumeration
+    NumOpClasses = 129 // exclusive last value of enumeration
   };
   // OPCODECLASS-ENUM:END
 
@@ -807,11 +850,12 @@ namespace DXIL {
     const unsigned kLoadInputColOpIdx = 3;
     const unsigned kLoadInputVertexIDOpIdx = 4;
 
-    // StoreOutput.
+    // StoreOutput, StoreVertexOutput, StorePrimitiveOutput
     const unsigned kStoreOutputIDOpIdx = 1;
     const unsigned kStoreOutputRowOpIdx = 2;
     const unsigned kStoreOutputColOpIdx = 3;
     const unsigned kStoreOutputValOpIdx = 4;
+    const unsigned kStoreOutputVPIDOpIdx = 5;
 
     // DomainLocation.
     const unsigned kDomainLocationColOpIdx = 1;
@@ -913,6 +957,14 @@ namespace DXIL {
 
     // Emit/Cut
     const unsigned kStreamEmitCutIDOpIdx = 1;
+
+    // StoreVectorOutput/StorePrimitiveOutput.
+    const unsigned kMSStoreOutputIDOpIdx = 1;
+    const unsigned kMSStoreOutputRowOpIdx = 2;
+    const unsigned kMSStoreOutputColOpIdx = 3;
+    const unsigned kMSStoreOutputVIdxOpIdx = 4;
+    const unsigned kMSStoreOutputValOpIdx = 5;
+
     // TODO: add operand index for all the OpCodeClass.
   }
 
@@ -1027,6 +1079,15 @@ namespace DXIL {
     LastEntry,
   };
 
+  enum class MeshOutputTopology
+  {
+    Undefined = 0,
+    Line = 1,
+    Triangle = 2,
+
+    LastEntry,
+  };
+
   // Tessellator partitioning, must match D3D_TESSELLATOR_PARTITIONING
   enum class TessellatorPartitioning : unsigned {
     Undefined = 0,

+ 17 - 1
include/dxc/DXIL/DxilFunctionProps.h

@@ -67,6 +67,19 @@ struct DxilFunctionProps {
       };
       unsigned attributeSizeInBytes;
     } Ray;
+    // Mesh shader.
+    struct {
+      unsigned numThreads[3];
+      unsigned maxVertexCount;
+      unsigned maxPrimitiveCount;
+      DXIL::MeshOutputTopology outputTopology;
+      // The following doesn't go into metadata
+      unsigned payloadByteSize;
+    } MS;
+    // Amplification shader.
+    struct {
+      unsigned numThreads[3];
+    } AS;
   } ShaderProps;
   DXIL::ShaderKind shaderKind;
   // TODO: Should we have an unmangled name here for ray tracing shaders?
@@ -77,7 +90,8 @@ struct DxilFunctionProps {
   bool IsDS() const     { return shaderKind == DXIL::ShaderKind::Domain; }
   bool IsCS() const     { return shaderKind == DXIL::ShaderKind::Compute; }
   bool IsGraphics() const {
-    return (shaderKind >= DXIL::ShaderKind::Pixel && shaderKind <= DXIL::ShaderKind::Domain);
+    return (shaderKind >= DXIL::ShaderKind::Pixel && shaderKind <= DXIL::ShaderKind::Domain) ||
+           shaderKind == DXIL::ShaderKind::Mesh || shaderKind == DXIL::ShaderKind::Amplification;
   }
   bool IsRayGeneration() const { return shaderKind == DXIL::ShaderKind::RayGeneration; }
   bool IsIntersection() const { return shaderKind == DXIL::ShaderKind::Intersection; }
@@ -88,6 +102,8 @@ struct DxilFunctionProps {
   bool IsRay() const {
     return (shaderKind >= DXIL::ShaderKind::RayGeneration && shaderKind <= DXIL::ShaderKind::Callable);
   }
+  bool IsMS() const { return shaderKind == DXIL::ShaderKind::Mesh; }
+  bool IsAS() const { return shaderKind == DXIL::ShaderKind::Amplification; }
 };
 
 } // namespace hlsl

+ 188 - 0
include/dxc/DXIL/DxilInstructions.h

@@ -5549,5 +5549,193 @@ struct DxilInst_WaveMultiPrefixBitCount {
   llvm::Value *get_mask3() const { return Instr->getOperand(5); }
   void set_mask3(llvm::Value *val) { Instr->setOperand(5, val); }
 };
+
+/// This instruction Mesh shader intrinsic SetMeshOutputCounts
+struct DxilInst_SetMeshOutputCounts {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_SetMeshOutputCounts(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::SetMeshOutputCounts);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (3 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_numVertices = 1,
+    arg_numPrimitives = 2,
+  };
+  // Accessors
+  llvm::Value *get_numVertices() const { return Instr->getOperand(1); }
+  void set_numVertices(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_numPrimitives() const { return Instr->getOperand(2); }
+  void set_numPrimitives(llvm::Value *val) { Instr->setOperand(2, val); }
+};
+
+/// This instruction emit a primitive's vertex indices in a mesh shader
+struct DxilInst_EmitIndices {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_EmitIndices(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::EmitIndices);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (5 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_PrimitiveIndex = 1,
+    arg_VertexIndex0 = 2,
+    arg_VertexIndex1 = 3,
+    arg_VertexIndex2 = 4,
+  };
+  // Accessors
+  llvm::Value *get_PrimitiveIndex() const { return Instr->getOperand(1); }
+  void set_PrimitiveIndex(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_VertexIndex0() const { return Instr->getOperand(2); }
+  void set_VertexIndex0(llvm::Value *val) { Instr->setOperand(2, val); }
+  llvm::Value *get_VertexIndex1() const { return Instr->getOperand(3); }
+  void set_VertexIndex1(llvm::Value *val) { Instr->setOperand(3, val); }
+  llvm::Value *get_VertexIndex2() const { return Instr->getOperand(4); }
+  void set_VertexIndex2(llvm::Value *val) { Instr->setOperand(4, val); }
+};
+
+/// This instruction get the mesh payload which is from amplification shader
+struct DxilInst_GetMeshPayload {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_GetMeshPayload(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::GetMeshPayload);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (1 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+};
+
+/// This instruction stores the value to mesh shader vertex output
+struct DxilInst_StoreVertexOutput {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_StoreVertexOutput(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::StoreVertexOutput);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (6 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_outputSigId = 1,
+    arg_rowIndex = 2,
+    arg_colIndex = 3,
+    arg_value = 4,
+    arg_vertexIndex = 5,
+  };
+  // Accessors
+  llvm::Value *get_outputSigId() const { return Instr->getOperand(1); }
+  void set_outputSigId(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_rowIndex() const { return Instr->getOperand(2); }
+  void set_rowIndex(llvm::Value *val) { Instr->setOperand(2, val); }
+  llvm::Value *get_colIndex() const { return Instr->getOperand(3); }
+  void set_colIndex(llvm::Value *val) { Instr->setOperand(3, val); }
+  llvm::Value *get_value() const { return Instr->getOperand(4); }
+  void set_value(llvm::Value *val) { Instr->setOperand(4, val); }
+  llvm::Value *get_vertexIndex() const { return Instr->getOperand(5); }
+  void set_vertexIndex(llvm::Value *val) { Instr->setOperand(5, val); }
+};
+
+/// This instruction stores the value to mesh shader primitive output
+struct DxilInst_StorePrimitiveOutput {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_StorePrimitiveOutput(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::StorePrimitiveOutput);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (6 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_outputSigId = 1,
+    arg_rowIndex = 2,
+    arg_colIndex = 3,
+    arg_value = 4,
+    arg_primitiveIndex = 5,
+  };
+  // Accessors
+  llvm::Value *get_outputSigId() const { return Instr->getOperand(1); }
+  void set_outputSigId(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_rowIndex() const { return Instr->getOperand(2); }
+  void set_rowIndex(llvm::Value *val) { Instr->setOperand(2, val); }
+  llvm::Value *get_colIndex() const { return Instr->getOperand(3); }
+  void set_colIndex(llvm::Value *val) { Instr->setOperand(3, val); }
+  llvm::Value *get_value() const { return Instr->getOperand(4); }
+  void set_value(llvm::Value *val) { Instr->setOperand(4, val); }
+  llvm::Value *get_primitiveIndex() const { return Instr->getOperand(5); }
+  void set_primitiveIndex(llvm::Value *val) { Instr->setOperand(5, val); }
+};
+
+/// This instruction Amplification shader intrinsic DispatchMesh
+struct DxilInst_DispatchMesh {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_DispatchMesh(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::DispatchMesh);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (5 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_threadGroupCountX = 1,
+    arg_threadGroupCountY = 2,
+    arg_threadGroupCountZ = 3,
+    arg_payload = 4,
+  };
+  // Accessors
+  llvm::Value *get_threadGroupCountX() const { return Instr->getOperand(1); }
+  void set_threadGroupCountX(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_threadGroupCountY() const { return Instr->getOperand(2); }
+  void set_threadGroupCountY(llvm::Value *val) { Instr->setOperand(2, val); }
+  llvm::Value *get_threadGroupCountZ() const { return Instr->getOperand(3); }
+  void set_threadGroupCountZ(llvm::Value *val) { Instr->setOperand(3, val); }
+  llvm::Value *get_payload() const { return Instr->getOperand(4); }
+  void set_payload(llvm::Value *val) { Instr->setOperand(4, val); }
+};
 // INSTR-HELPER:END
 } // namespace hlsl

+ 26 - 0
include/dxc/DXIL/DxilMetadataHelper.h

@@ -220,6 +220,8 @@ public:
   static const unsigned kDxilRayPayloadSizeTag  = 6;
   static const unsigned kDxilRayAttribSizeTag   = 7;
   static const unsigned kDxilShaderKindTag      = 8;
+  static const unsigned kDxilMSStateTag         = 9;
+  static const unsigned kDxilASStateTag         = 10;
 
   // GSState.
   static const unsigned kDxilGSStateNumFields               = 5;
@@ -244,6 +246,17 @@ public:
   static const unsigned kDxilHSStateTessellatorOutputPrimitive= 5;
   static const unsigned kDxilHSStateMaxTessellationFactor     = 6;
 
+  // MSState.
+  static const unsigned kDxilMSStateNumFields = 4;
+  static const unsigned kDxilMSStateNumThreads = 0;
+  static const unsigned kDxilMSStateMaxVertexCount = 1;
+  static const unsigned kDxilMSStateMaxPrimitiveCount = 2;
+  static const unsigned kDxilMSStateOutputTopology = 3;
+
+  // ASState.
+  static const unsigned kDxilASStateNumFields = 1;
+  static const unsigned kDxilASStateNumThreads = 0;
+
 public:
   /// Use this class to manipulate metadata of DXIL or high-level DX IR specific fields in the record.
   class ExtraPropertyHelper {
@@ -406,6 +419,19 @@ private:
                        DXIL::TessellatorPartitioning &TessPartitioning,
                        DXIL::TessellatorOutputPrimitive &TessOutputPrimitive,
                        float &MaxTessFactor);
+
+  llvm::MDTuple *EmitDxilMSState(const unsigned *NumThreads,
+                                 unsigned MaxVertexCount,
+                                 unsigned MaxPrimitiveCount,
+                                 DXIL::MeshOutputTopology OutputTopology);
+  void LoadDxilMSState(const llvm::MDOperand &MDO,
+                       unsigned *NumThreads,
+                       unsigned &MaxVertexCount,
+                       unsigned &MaxPrimitiveCount,
+                       DXIL::MeshOutputTopology &OutputTopology);
+
+  llvm::MDTuple *EmitDxilASState(const unsigned *NumThreads);
+  void LoadDxilASState(const llvm::MDOperand &MDO, unsigned *NumThreads);
 public:
   // Utility functions.
   static bool IsKnownNamedMetaData(const llvm::NamedMDNode &Node);

+ 12 - 2
include/dxc/DXIL/DxilModule.h

@@ -118,8 +118,8 @@ public:
   const DxilSignature &GetInputSignature() const;
   DxilSignature &GetOutputSignature();
   const DxilSignature &GetOutputSignature() const;
-  DxilSignature &GetPatchConstantSignature();
-  const DxilSignature &GetPatchConstantSignature() const;
+  DxilSignature &GetPatchConstOrPrimSignature();
+  const DxilSignature &GetPatchConstOrPrimSignature() const;
   const std::vector<uint8_t> &GetSerializedRootSignature() const;
   std::vector<uint8_t> &GetSerializedRootSignature();
 
@@ -273,6 +273,16 @@ public:
   float GetMaxTessellationFactor() const;
   void SetMaxTessellationFactor(float MaxTessellationFactor);
 
+  // Mesh shader
+  unsigned GetMaxOutputVertices() const;
+  void SetMaxOutputVertices(unsigned NumOVs);
+  unsigned GetMaxOutputPrimitives() const;
+  void SetMaxOutputPrimitives(unsigned NumOPs);
+  DXIL::MeshOutputTopology GetMeshOutputTopology() const;
+  void SetMeshOutputTopology(DXIL::MeshOutputTopology MeshOutputTopology);
+  unsigned GetPayloadByteSize() const;
+  void SetPayloadByteSize(unsigned Size);
+
   // AutoBindingSpace also enables automatic binding for libraries if set.
   // UINT_MAX == unset
   void SetAutoBindingSpace(uint32_t Space);

+ 3 - 1
include/dxc/DXIL/DxilShaderModel.h

@@ -40,6 +40,8 @@ public:
   bool IsCS() const     { return m_Kind == Kind::Compute; }
   bool IsLib() const    { return m_Kind == Kind::Library; }
   bool IsRay() const    { return m_Kind >= Kind::RayGeneration && m_Kind <= Kind::Callable; }
+  bool IsMS() const     { return m_Kind == Kind::Mesh; }
+  bool IsAS() const     { return m_Kind == Kind::Amplification; }
   bool IsValid() const;
   bool IsValidForDxil() const;
   bool IsValidForModule() const;
@@ -96,7 +98,7 @@ private:
               unsigned m_NumInputRegs, unsigned m_NumOutputRegs,
               bool m_bUAVs, bool m_bTypedUavs, unsigned m_UAVRegsLim);
 
-  static const unsigned kNumShaderModels = 63;
+  static const unsigned kNumShaderModels = 65;
   static const ShaderModel ms_ShaderModels[kNumShaderModels];
 
   static const ShaderModel *GetInvalid();

+ 1 - 1
include/dxc/DXIL/DxilSigPoint.h

@@ -34,7 +34,7 @@ public:
 
   bool IsInput() const { return m_SignatureKind == DXIL::SignatureKind::Input; }
   bool IsOutput() const { return m_SignatureKind == DXIL::SignatureKind::Output; }
-  bool IsPatchConstant() const { return m_SignatureKind == DXIL::SignatureKind::PatchConstant; }
+  bool IsPatchConstOrPrim() const { return m_SignatureKind == DXIL::SignatureKind::PatchConstOrPrim; }
 
   Kind GetKind() const { return m_Kind; }
   const char *GetName() const { return m_pszName; }

+ 78 - 52
include/dxc/DXIL/DxilSigPoint.inl

@@ -19,25 +19,29 @@ namespace hlsl {
 // for compatibility purposes.
 // <py::lines('SIGPOINT-TABLE')>hctdb_instrhelp.get_sigpoint_table()</py>
 // SIGPOINT-TABLE:BEGIN
-//   SigPoint, Related, ShaderKind, PackingKind,    SignatureKind
+//   SigPoint, Related, ShaderKind,    PackingKind,    SignatureKind
 #define DO_SIGPOINTS(ROW) \
-  ROW(VSIn,     Invalid, Vertex,     InputAssembler, Input) \
-  ROW(VSOut,    Invalid, Vertex,     Vertex,         Output) \
-  ROW(PCIn,     HSCPIn,  Hull,       None,           Invalid) \
-  ROW(HSIn,     HSCPIn,  Hull,       None,           Invalid) \
-  ROW(HSCPIn,   Invalid, Hull,       Vertex,         Input) \
-  ROW(HSCPOut,  Invalid, Hull,       Vertex,         Output) \
-  ROW(PCOut,    Invalid, Hull,       PatchConstant,  PatchConstant) \
-  ROW(DSIn,     Invalid, Domain,     PatchConstant,  PatchConstant) \
-  ROW(DSCPIn,   Invalid, Domain,     Vertex,         Input) \
-  ROW(DSOut,    Invalid, Domain,     Vertex,         Output) \
-  ROW(GSVIn,    Invalid, Geometry,   Vertex,         Input) \
-  ROW(GSIn,     GSVIn,   Geometry,   None,           Invalid) \
-  ROW(GSOut,    Invalid, Geometry,   Vertex,         Output) \
-  ROW(PSIn,     Invalid, Pixel,      Vertex,         Input) \
-  ROW(PSOut,    Invalid, Pixel,      Target,         Output) \
-  ROW(CSIn,     Invalid, Compute,    None,           Invalid) \
-  ROW(Invalid,  Invalid, Invalid,    Invalid,        Invalid)
+  ROW(VSIn,     Invalid, Vertex,        InputAssembler, Input) \
+  ROW(VSOut,    Invalid, Vertex,        Vertex,         Output) \
+  ROW(PCIn,     HSCPIn,  Hull,          None,           Invalid) \
+  ROW(HSIn,     HSCPIn,  Hull,          None,           Invalid) \
+  ROW(HSCPIn,   Invalid, Hull,          Vertex,         Input) \
+  ROW(HSCPOut,  Invalid, Hull,          Vertex,         Output) \
+  ROW(PCOut,    Invalid, Hull,          PatchConstant,  PatchConstOrPrim) \
+  ROW(DSIn,     Invalid, Domain,        PatchConstant,  PatchConstOrPrim) \
+  ROW(DSCPIn,   Invalid, Domain,        Vertex,         Input) \
+  ROW(DSOut,    Invalid, Domain,        Vertex,         Output) \
+  ROW(GSVIn,    Invalid, Geometry,      Vertex,         Input) \
+  ROW(GSIn,     GSVIn,   Geometry,      None,           Invalid) \
+  ROW(GSOut,    Invalid, Geometry,      Vertex,         Output) \
+  ROW(PSIn,     Invalid, Pixel,         Vertex,         Input) \
+  ROW(PSOut,    Invalid, Pixel,         Target,         Output) \
+  ROW(CSIn,     Invalid, Compute,       None,           Invalid) \
+  ROW(MSIn,     Invalid, Mesh,          None,           Invalid) \
+  ROW(MSOut,    Invalid, Mesh,          Vertex,         Output) \
+  ROW(MSPOut,   Invalid, Mesh,          Vertex,         PatchConstOrPrim) \
+  ROW(ASIn,     Invalid, Amplification, None,           Invalid) \
+  ROW(Invalid,  Invalid, Invalid,       Invalid,        Invalid)
 // SIGPOINT-TABLE:END
 
 const SigPoint SigPoint::ms_SigPoints[kNumSigPointRecords] = {
@@ -49,38 +53,39 @@ const SigPoint SigPoint::ms_SigPoints[kNumSigPointRecords] = {
 
 // <py::lines('INTERPRETATION-TABLE')>hctdb_instrhelp.get_interpretation_table()</py>
 // INTERPRETATION-TABLE:BEGIN
-//   Semantic,               VSIn,         VSOut,    PCIn,         HSIn,         HSCPIn,   HSCPOut,  PCOut,      DSIn,         DSCPIn,   DSOut,    GSVIn,    GSIn,         GSOut,    PSIn,          PSOut,         CSIn
+//   Semantic,               VSIn,         VSOut,    PCIn,         HSIn,         HSCPIn,   HSCPOut,  PCOut,      DSIn,         DSCPIn,   DSOut,    GSVIn,    GSIn,         GSOut,    PSIn,          PSOut,         CSIn,     MSIn,         MSOut,        MSPOut,  ASIn
 #define DO_INTERPRETATION_TABLE(ROW) \
-  ROW(Arbitrary,              Arb,          Arb,      NA,           NA,           Arb,      Arb,      Arb,        Arb,          Arb,      Arb,      Arb,      NA,           Arb,      Arb,           NA,            NA) \
-  ROW(VertexID,               SV,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA) \
-  ROW(InstanceID,             SV,           Arb,      NA,           NA,           Arb,      Arb,      NA,         NA,           Arb,      Arb,      Arb,      NA,           Arb,      Arb,           NA,            NA) \
-  ROW(Position,               Arb,          SV,       NA,           NA,           SV,       SV,       Arb,        Arb,          SV,       SV,       SV,       NA,           SV,       SV,            NA,            NA) \
-  ROW(RenderTargetArrayIndex, Arb,          SV,       NA,           NA,           SV,       SV,       Arb,        Arb,          SV,       SV,       SV,       NA,           SV,       SV,            NA,            NA) \
-  ROW(ViewPortArrayIndex,     Arb,          SV,       NA,           NA,           SV,       SV,       Arb,        Arb,          SV,       SV,       SV,       NA,           SV,       SV,            NA,            NA) \
-  ROW(ClipDistance,           Arb,          ClipCull, NA,           NA,           ClipCull, ClipCull, Arb,        Arb,          ClipCull, ClipCull, ClipCull, NA,           ClipCull, ClipCull,      NA,            NA) \
-  ROW(CullDistance,           Arb,          ClipCull, NA,           NA,           ClipCull, ClipCull, Arb,        Arb,          ClipCull, ClipCull, ClipCull, NA,           ClipCull, ClipCull,      NA,            NA) \
-  ROW(OutputControlPointID,   NA,           NA,       NA,           NotInSig,     NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA) \
-  ROW(DomainLocation,         NA,           NA,       NA,           NA,           NA,       NA,       NA,         NotInSig,     NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA) \
-  ROW(PrimitiveID,            NA,           NA,       NotInSig,     NotInSig,     NA,       NA,       NA,         NotInSig,     NA,       NA,       NA,       Shadow,       SGV,      SGV,           NA,            NA) \
-  ROW(GSInstanceID,           NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NotInSig,     NA,       NA,            NA,            NA) \
-  ROW(SampleIndex,            NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       Shadow _41,    NA,            NA) \
-  ROW(IsFrontFace,            NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           SGV,      SGV,           NA,            NA) \
-  ROW(Coverage,               NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NotInSig _50,  NotPacked _41, NA) \
-  ROW(InnerCoverage,          NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NotInSig _50,  NA,            NA) \
-  ROW(Target,                 NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            Target,        NA) \
-  ROW(Depth,                  NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked,     NA) \
-  ROW(DepthLessEqual,         NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked _50, NA) \
-  ROW(DepthGreaterEqual,      NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked _50, NA) \
-  ROW(StencilRef,             NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked _50, NA) \
-  ROW(DispatchThreadID,       NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig) \
-  ROW(GroupID,                NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig) \
-  ROW(GroupIndex,             NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig) \
-  ROW(GroupThreadID,          NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig) \
-  ROW(TessFactor,             NA,           NA,       NA,           NA,           NA,       NA,       TessFactor, TessFactor,   NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA) \
-  ROW(InsideTessFactor,       NA,           NA,       NA,           NA,           NA,       NA,       TessFactor, TessFactor,   NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA) \
-  ROW(ViewID,                 NotInSig _61, NA,       NotInSig _61, NotInSig _61, NA,       NA,       NA,         NotInSig _61, NA,       NA,       NA,       NotInSig _61, NA,       NotInSig _61,  NA,            NA) \
-  ROW(Barycentrics,           NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NotPacked _61, NA,            NA) \
-  ROW(ShadingRate,            NA,           SV _64,   NA,           NA,           SV _64,   SV _64,   NA,         NA,           SV _64,   SV _64,   SV _64,   NA,           SV _64,   SV _64,        NA,            NA)
+  ROW(Arbitrary,              Arb,          Arb,      NA,           NA,           Arb,      Arb,      Arb,        Arb,          Arb,      Arb,      Arb,      NA,           Arb,      Arb,           NA,            NA,       NA,           Arb _65,      Arb _65, NA) \
+  ROW(VertexID,               SV,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(InstanceID,             SV,           Arb,      NA,           NA,           Arb,      Arb,      NA,         NA,           Arb,      Arb,      Arb,      NA,           Arb,      Arb,           NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(Position,               Arb,          SV,       NA,           NA,           SV,       SV,       Arb,        Arb,          SV,       SV,       SV,       NA,           SV,       SV,            NA,            NA,       NA,           SV _65,       NA,      NA) \
+  ROW(RenderTargetArrayIndex, Arb,          SV,       NA,           NA,           SV,       SV,       Arb,        Arb,          SV,       SV,       SV,       NA,           SV,       SV,            NA,            NA,       NA,           NA,           SV _65,  NA) \
+  ROW(ViewPortArrayIndex,     Arb,          SV,       NA,           NA,           SV,       SV,       Arb,        Arb,          SV,       SV,       SV,       NA,           SV,       SV,            NA,            NA,       NA,           NA,           SV _65,  NA) \
+  ROW(ClipDistance,           Arb,          ClipCull, NA,           NA,           ClipCull, ClipCull, Arb,        Arb,          ClipCull, ClipCull, ClipCull, NA,           ClipCull, ClipCull,      NA,            NA,       NA,           ClipCull _65, NA,      NA) \
+  ROW(CullDistance,           Arb,          ClipCull, NA,           NA,           ClipCull, ClipCull, Arb,        Arb,          ClipCull, ClipCull, ClipCull, NA,           ClipCull, ClipCull,      NA,            NA,       NA,           ClipCull _65, NA,      NA) \
+  ROW(OutputControlPointID,   NA,           NA,       NA,           NotInSig,     NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(DomainLocation,         NA,           NA,       NA,           NA,           NA,       NA,       NA,         NotInSig,     NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(PrimitiveID,            NA,           NA,       NotInSig,     NotInSig,     NA,       NA,       NA,         NotInSig,     NA,       NA,       NA,       Shadow,       SGV,      SGV,           NA,            NA,       NA,           NA,           SV _65,  NA) \
+  ROW(GSInstanceID,           NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NotInSig,     NA,       NA,            NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(SampleIndex,            NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       Shadow _41,    NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(IsFrontFace,            NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           SGV,      SGV,           NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(Coverage,               NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NotInSig _50,  NotPacked _41, NA,       NA,           NA,           NA,      NA) \
+  ROW(InnerCoverage,          NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NotInSig _50,  NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(Target,                 NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            Target,        NA,       NA,           NA,           NA,      NA) \
+  ROW(Depth,                  NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked,     NA,       NA,           NA,           NA,      NA) \
+  ROW(DepthLessEqual,         NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked _50, NA,       NA,           NA,           NA,      NA) \
+  ROW(DepthGreaterEqual,      NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked _50, NA,       NA,           NA,           NA,      NA) \
+  ROW(StencilRef,             NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NotPacked _50, NA,       NA,           NA,           NA,      NA) \
+  ROW(DispatchThreadID,       NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig, NotInSig _65, NA,           NA,      NotInSig _65) \
+  ROW(GroupID,                NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig, NotInSig _65, NA,           NA,      NotInSig _65) \
+  ROW(GroupIndex,             NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig, NotInSig _65, NA,           NA,      NotInSig _65) \
+  ROW(GroupThreadID,          NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NotInSig, NotInSig _65, NA,           NA,      NotInSig _65) \
+  ROW(TessFactor,             NA,           NA,       NA,           NA,           NA,       NA,       TessFactor, TessFactor,   NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(InsideTessFactor,       NA,           NA,       NA,           NA,           NA,       NA,       TessFactor, TessFactor,   NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(ViewID,                 NotInSig _61, NA,       NotInSig _61, NotInSig _61, NA,       NA,       NA,         NotInSig _61, NA,       NA,       NA,       NotInSig _61, NA,       NotInSig _61,  NA,            NA,       NotInSig _65, NA,           NA,      NA) \
+  ROW(Barycentrics,           NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NotPacked _61, NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(ShadingRate,            NA,           SV _64,   NA,           NA,           SV _64,   SV _64,   NA,         NA,           SV _64,   SV _64,   SV _64,   NA,           SV _64,   SV _64,        NA,            NA,       NA,           NA,           NA,      NA) \
+  ROW(CullPrimitive,          NA,           NA,       NA,           NA,           NA,       NA,       NA,         NA,           NA,       NA,       NA,       NA,           NA,       NA,            NA,            NA,       NA,           NA,           SV _65,  NA)
 // INTERPRETATION-TABLE:END
 
 const VersionedSemanticInterpretation SigPoint::ms_SemanticInterpretationTable[(unsigned)DXIL::SemanticKind::Invalid][(unsigned)SigPoint::Kind::Invalid] = {
@@ -88,7 +93,8 @@ const VersionedSemanticInterpretation SigPoint::ms_SemanticInterpretationTable[(
 #define _50 ,5,0
 #define _61 ,6,1
 #define _64 ,6,4
-#define DO_ROW(SEM, VSIn, VSOut, PCIn, HSIn, HSCPIn, HSCPOut, PCOut, DSIn, DSCPIn, DSOut, GSVIn, GSIn, GSOut, PSIn, PSOut, CSIn) \
+#define _65 ,6,5
+#define DO_ROW(SEM, VSIn, VSOut, PCIn, HSIn, HSCPIn, HSCPOut, PCOut, DSIn, DSCPIn, DSOut, GSVIn, GSIn, GSOut, PSIn, PSOut, CSIn, MSIn, MSOut, MSPOut, ASIn) \
   { VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::VSIn), \
     VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::VSOut), \
     VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::PCIn), \
@@ -105,6 +111,10 @@ const VersionedSemanticInterpretation SigPoint::ms_SemanticInterpretationTable[(
     VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::PSIn), \
     VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::PSOut), \
     VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::CSIn), \
+    VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::MSIn), \
+    VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::MSOut), \
+    VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::MSPOut), \
+    VersionedSemanticInterpretation(DXIL::SemanticInterpretationKind::ASIn), \
   },
   DO_INTERPRETATION_TABLE(DO_ROW)
 #undef DO_ROW
@@ -179,7 +189,7 @@ DXIL::SigPointKind SigPoint::GetKind(DXIL::ShaderKind shaderKind, DXIL::Signatur
     switch (sigKind) {
     case DXIL::SignatureKind::Input: return DXIL::SigPointKind::HSCPIn;
     case DXIL::SignatureKind::Output: return DXIL::SigPointKind::HSCPOut;
-    case DXIL::SignatureKind::PatchConstant: return DXIL::SigPointKind::PCOut;
+    case DXIL::SignatureKind::PatchConstOrPrim: return DXIL::SigPointKind::PCOut;
     default:
       break;
     }
@@ -188,7 +198,7 @@ DXIL::SigPointKind SigPoint::GetKind(DXIL::ShaderKind shaderKind, DXIL::Signatur
     switch (sigKind) {
     case DXIL::SignatureKind::Input: return DXIL::SigPointKind::DSCPIn;
     case DXIL::SignatureKind::Output: return DXIL::SigPointKind::DSOut;
-    case DXIL::SignatureKind::PatchConstant: return DXIL::SigPointKind::DSIn;
+    case DXIL::SignatureKind::PatchConstOrPrim: return DXIL::SigPointKind::DSIn;
     default:
       break;
     }
@@ -216,6 +226,22 @@ DXIL::SigPointKind SigPoint::GetKind(DXIL::ShaderKind shaderKind, DXIL::Signatur
       break;
     }
     break;
+  case DXIL::ShaderKind::Mesh:
+    switch (sigKind) {
+    case DXIL::SignatureKind::Input: return DXIL::SigPointKind::MSIn;
+    case DXIL::SignatureKind::Output: return DXIL::SigPointKind::MSOut;
+    case DXIL::SignatureKind::PatchConstOrPrim: return DXIL::SigPointKind::MSPOut;
+    default:
+      break;
+    }
+    break;
+  case DXIL::ShaderKind::Amplification:
+    switch (sigKind) {
+    case DXIL::SignatureKind::Input: return DXIL::SigPointKind::ASIn;
+    default:
+      break;
+    }
+    break;
   default:
     break;
   }

+ 4 - 2
include/dxc/DXIL/DxilSignature.h

@@ -52,6 +52,8 @@ public:
 
   static bool ShouldBeAllocated(DXIL::SemanticInterpretationKind);
 
+  unsigned GetRowCount() const;
+
 private:
   DXIL::SigPointKind m_sigPointKind;
   std::vector<std::unique_ptr<DxilSignatureElement> > m_Elements;
@@ -62,12 +64,12 @@ struct DxilEntrySignature {
   DxilEntrySignature(DXIL::ShaderKind shaderKind, bool useMinPrecision)
       : InputSignature(shaderKind, DxilSignature::Kind::Input, useMinPrecision),
         OutputSignature(shaderKind, DxilSignature::Kind::Output, useMinPrecision),
-        PatchConstantSignature(shaderKind, DxilSignature::Kind::PatchConstant, useMinPrecision) {
+        PatchConstOrPrimSignature(shaderKind, DxilSignature::Kind::PatchConstOrPrim, useMinPrecision) {
   }
   DxilEntrySignature(const DxilEntrySignature &src);
   DxilSignature InputSignature;
   DxilSignature OutputSignature;
-  DxilSignature PatchConstantSignature;
+  DxilSignature PatchConstOrPrimSignature;
 };
 
 } // namespace hlsl

+ 1 - 1
include/dxc/DXIL/DxilSignatureElement.h

@@ -50,7 +50,7 @@ public:
 
   bool IsInput() const;
   bool IsOutput() const;
-  bool IsPatchConstant() const;
+  bool IsPatchConstOrPrim() const;
   const char *GetName() const;
   unsigned GetRows() const;
   void SetRows(unsigned Rows);

+ 4 - 0
include/dxc/DXIL/DxilTypeSystem.h

@@ -123,6 +123,10 @@ enum class DxilParamInputQual {
   OutStream2,
   OutStream3,
   InputPrimitive,
+  OutIndices,
+  OutVertices,
+  OutPrimitives,
+  InPayload,
 };
 
 /// Use this class to represent type annotation for function parameter.

+ 91 - 74
include/dxc/DxilContainer/DxilPipelineStateValidation.h

@@ -49,6 +49,13 @@ struct PSInfo {
   char DepthOutput;
   char SampleFrequency;
 };
+struct MSInfo {
+  uint32_t GroupSharedBytesUsed;
+  uint32_t GroupSharedBytesDependentOnViewID;
+  uint32_t PayloadSizeInBytes;
+  uint16_t MaxOutputVertices;
+  uint16_t MaxOutputPrimitives;
+};
 
 // Versioning is additive and based on size
 struct PSVRuntimeInfo0
@@ -59,6 +66,7 @@ struct PSVRuntimeInfo0
     DSInfo DS;
     GSInfo GS;
     PSInfo PS;
+    MSInfo MS;
   };
   uint32_t MinimumExpectedWaveLaneCount;  // minimum lane count required, 0 if unused
   uint32_t MaximumExpectedWaveLaneCount;  // maximum lane count required, 0xffffffff if unused
@@ -79,6 +87,8 @@ enum class PSVShaderKind : uint8_t    // DXIL::ShaderKind
   ClosestHit,
   Miss,
   Callable,
+  Mesh,
+  Amplification,
   Invalid,
 };
 
@@ -88,16 +98,19 @@ struct PSVRuntimeInfo1 : public PSVRuntimeInfo0
   uint8_t UsesViewID;
   union {
     uint16_t MaxVertexCount;          // MaxVertexCount for GS only (max 1024)
-    uint8_t SigPatchConstantVectors;  // Output for HS; Input for DS
+    uint8_t SigPatchConstOrPrimVectors;  // Output for HS; Input for DS; Primitive output for MS
   };
 
   // PSVSignatureElement counts
   uint8_t SigInputElements;
   uint8_t SigOutputElements;
-  uint8_t SigPatchConstantElements;
+  uint8_t SigPatchConstOrPrimElements;
 
   // Number of packed vectors per signature
-  uint8_t SigInputVectors;
+  union {
+    uint8_t SigInputVectors;
+    uint8_t MeshOutputTopology;
+  };
   uint8_t SigOutputVectors[4];      // Array for GS Stream Out Index
 };
 
@@ -320,9 +333,9 @@ struct PSVInitInfo
     UsesViewID(0),
     SigInputElements(0),
     SigOutputElements(0),
-    SigPatchConstantElements(0),
+    SigPatchConstOrPrimElements(0),
     SigInputVectors(0),
-    SigPatchConstantVectors(0)
+    SigPatchConstOrPrimVectors(0)
   {}
   uint32_t PSVVersion;
   uint32_t ResourceCount;
@@ -332,9 +345,9 @@ struct PSVInitInfo
   uint8_t UsesViewID;
   uint8_t SigInputElements;
   uint8_t SigOutputElements;
-  uint8_t SigPatchConstantElements;
+  uint8_t SigPatchConstOrPrimElements;
   uint8_t SigInputVectors;
-  uint8_t SigPatchConstantVectors;
+  uint8_t SigPatchConstOrPrimVectors;
   uint8_t SigOutputVectors[4] = {0, 0, 0, 0};
 };
 
@@ -351,9 +364,9 @@ class DxilPipelineStateValidation
   uint32_t m_uPSVSignatureElementSize;
   void* m_pSigInputElements;
   void* m_pSigOutputElements;
-  void* m_pSigPatchConstantElements;
+  void* m_pSigPatchConstOrPrimElements;
   uint32_t* m_pViewIDOutputMask;
-  uint32_t* m_pViewIDPCOutputMask;
+  uint32_t* m_pViewIDPCOrPrimOutputMask;
   uint32_t* m_pInputToOutputTable;
   uint32_t* m_pInputToPCOutputTable;
   uint32_t* m_pPCInputToOutputTable;
@@ -371,9 +384,9 @@ public:
     m_uPSVSignatureElementSize(0),
     m_pSigInputElements(nullptr),
     m_pSigOutputElements(nullptr),
-    m_pSigPatchConstantElements(nullptr),
+    m_pSigPatchConstOrPrimElements(nullptr),
     m_pViewIDOutputMask(nullptr),
-    m_pViewIDPCOutputMask(nullptr),
+    m_pViewIDPCOrPrimOutputMask(nullptr),
     m_pInputToOutputTable(nullptr),
     m_pInputToPCOutputTable(nullptr),
     m_pPCInputToOutputTable(nullptr)
@@ -394,28 +407,28 @@ public:
   //    uint32_t SemanticIndexTableEntries (number of dwords)
   //    If SemanticIndexTableEntries:
   //      { semantic index } * SemanticIndexTableEntries
-  //    If SigInputElements || SigOutputElements || SigPatchConstantElements:
+  //    If SigInputElements || SigOutputElements || SigPatchConstOrPrimElements:
   //      uint32_t PSVSignatureElement_size
   //      { PSVSignatureElementN structure } * SigInputElements
   //      { PSVSignatureElementN structure } * SigOutputElements
-  //      { PSVSignatureElementN structure } * SigPatchConstantElements
+  //      { PSVSignatureElementN structure } * SigPatchConstOrPrimElements
   //    If (UsesViewID):
   //      For (i : each stream index 0-3):
   //        If (SigOutputVectors[i] non-zero):
   //          { uint32_t * PSVComputeMaskDwordsFromVectors(SigOutputVectors[i]) }
   //            - Outputs affected by ViewID as a bitmask
-  //      If (HS and SigPatchConstantVectors non-zero):
-  //        { uint32_t * PSVComputeMaskDwordsFromVectors(SigPatchConstantVectors) }
+  //      If (HS and SigPatchConstOrPrimVectors non-zero):
+  //        { uint32_t * PSVComputeMaskDwordsFromVectors(SigPatchConstOrPrimVectors) }
   //          - PCOutputs affected by ViewID as a bitmask
   //    For (i : each stream index 0-3):
   //      If (SigInputVectors and SigOutputVectors[i] non-zero):
   //        { PSVComputeInputOutputTableSize(SigInputVectors, SigOutputVectors[i]) }
   //          - Outputs affected by inputs as a table of bitmasks
-  //    If (HS and SigPatchConstantVectors and SigInputVectors non-zero):
-  //      { PSVComputeInputOutputTableSize(SigInputVectors, SigPatchConstantVectors) }
+  //    If (HS and SigPatchConstOrPrimVectors and SigInputVectors non-zero):
+  //      { PSVComputeInputOutputTableSize(SigInputVectors, SigPatchConstOrPrimVectors) }
   //        - Patch constant outputs affected by inputs as a table of bitmasks
-  //    If (DS and SigOutputVectors[0] and SigPatchConstantVectors non-zero):
-  //      { PSVComputeInputOutputTableSize(SigPatchConstantVectors, SigOutputVectors[0]) }
+  //    If (DS and SigOutputVectors[0] and SigPatchConstOrPrimVectors non-zero):
+  //      { PSVComputeInputOutputTableSize(SigPatchConstOrPrimVectors, SigOutputVectors[0]) }
   //        - Outputs affected by patch constant inputs as a table of bitmasks
   // returns true if no errors occurred.
   bool InitFromPSV0(const void* pBits, uint32_t size) {
@@ -466,14 +479,14 @@ public:
       pCurBits += sizeof(uint32_t) * m_SemanticIndexTable.Entries;
 
       // Dxil Signature Elements
-      if (m_pPSVRuntimeInfo1->SigInputElements || m_pPSVRuntimeInfo1->SigOutputElements || m_pPSVRuntimeInfo1->SigPatchConstantElements) {
+      if (m_pPSVRuntimeInfo1->SigInputElements || m_pPSVRuntimeInfo1->SigOutputElements || m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements) {
         minsize += sizeof(uint32_t);
         if (!(size >= minsize)) return false;
         m_uPSVSignatureElementSize = *(uint32_t*)pCurBits;
         if (m_uPSVSignatureElementSize < sizeof(PSVSignatureElement0))
           return false;   // Illegal: Size smaller than first version
         pCurBits += sizeof(uint32_t);
-        minsize += m_uPSVSignatureElementSize * (m_pPSVRuntimeInfo1->SigInputElements + m_pPSVRuntimeInfo1->SigOutputElements + m_pPSVRuntimeInfo1->SigPatchConstantElements);
+        minsize += m_uPSVSignatureElementSize * (m_pPSVRuntimeInfo1->SigInputElements + m_pPSVRuntimeInfo1->SigOutputElements + m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements);
         if (!(size >= minsize)) return false;
       }
       if (m_pPSVRuntimeInfo1->SigInputElements) {
@@ -484,9 +497,9 @@ public:
         m_pSigOutputElements = (PSVSignatureElement0*)pCurBits;
         pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigOutputElements;
       }
-      if (m_pPSVRuntimeInfo1->SigPatchConstantElements) {
-        m_pSigPatchConstantElements = (PSVSignatureElement0*)pCurBits;
-        pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigPatchConstantElements;
+      if (m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements) {
+        m_pSigPatchConstOrPrimElements = (PSVSignatureElement0*)pCurBits;
+        pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements;
       }
 
       // ViewID dependencies
@@ -501,11 +514,11 @@ public:
           if (!IsGS())
             break;
         }
-        if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstantVectors) {
-          minsize += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+        if ((IsHS() || IsMS()) && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors) {
+          minsize += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
           if (!(size >= minsize)) return false;
-          m_pViewIDPCOutputMask = (uint32_t*)pCurBits;
-          pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+          m_pViewIDPCOrPrimOutputMask = (uint32_t*)pCurBits;
+          pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
         }
       }
 
@@ -520,17 +533,17 @@ public:
         if (!IsGS())
           break;
       }
-      if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstantVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
-        minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+      if ((IsHS() || IsMS()) && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
+        minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
         if (!(size >= minsize)) return false;
         m_pInputToPCOutputTable = (uint32_t*)pCurBits;
-        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
       }
-      if (IsDS() && m_pPSVRuntimeInfo1->SigOutputVectors[0] > 0 && m_pPSVRuntimeInfo1->SigPatchConstantVectors > 0) {
-        minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstantVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
+      if (IsDS() && m_pPSVRuntimeInfo1->SigOutputVectors[0] > 0 && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0) {
+        minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
         if (!(size >= minsize)) return false;
         m_pPCInputToOutputTable = (uint32_t*)pCurBits;
-        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstantVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
+        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
       }
     }
     return true;
@@ -567,12 +580,12 @@ public:
     if (initInfo.PSVVersion > 0) {
       size += sizeof(uint32_t) + PSVALIGN4(initInfo.StringTable.Size);
       size += sizeof(uint32_t) + sizeof(uint32_t) * initInfo.SemanticIndexTable.Entries;
-      if (initInfo.SigInputElements || initInfo.SigOutputElements || initInfo.SigPatchConstantElements) {
+      if (initInfo.SigInputElements || initInfo.SigOutputElements || initInfo.SigPatchConstOrPrimElements) {
         size += sizeof(uint32_t);   // PSVSignatureElement_size
       }
       size += m_uPSVSignatureElementSize * initInfo.SigInputElements;
       size += m_uPSVSignatureElementSize * initInfo.SigOutputElements;
-      size += m_uPSVSignatureElementSize * initInfo.SigPatchConstantElements;
+      size += m_uPSVSignatureElementSize * initInfo.SigPatchConstOrPrimElements;
 
       if (initInfo.UsesViewID) {
         for (unsigned i = 0; i < 4; i++) {
@@ -580,21 +593,23 @@ public:
           if (initInfo.ShaderStage != PSVShaderKind::Geometry)
             break;
         }
-        if (initInfo.ShaderStage == PSVShaderKind::Hull)
-          size += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(initInfo.SigPatchConstantVectors);
+        if (initInfo.ShaderStage == PSVShaderKind::Hull || initInfo.ShaderStage == PSVShaderKind::Mesh)
+          size += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(initInfo.SigPatchConstOrPrimVectors);
       }
-      for (unsigned i = 0; i < 4; i++) {
-        if (initInfo.SigOutputVectors[i] > 0 && initInfo.SigInputVectors > 0) {
-          size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigOutputVectors[i]);
-          if (initInfo.ShaderStage != PSVShaderKind::Geometry)
-            break;
+      if (initInfo.ShaderStage != PSVShaderKind::Mesh && initInfo.ShaderStage != PSVShaderKind::Amplification) {
+        for (unsigned i = 0; i < 4; i++) {
+          if (initInfo.SigOutputVectors[i] > 0 && initInfo.SigInputVectors > 0) {
+            size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigOutputVectors[i]);
+            if (initInfo.ShaderStage != PSVShaderKind::Geometry)
+              break;
+          }
+        }
+        if (initInfo.ShaderStage == PSVShaderKind::Hull && initInfo.SigPatchConstOrPrimVectors > 0 && initInfo.SigInputVectors > 0) {
+          size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigPatchConstOrPrimVectors);
+        }
+        if (initInfo.ShaderStage == PSVShaderKind::Domain && initInfo.SigOutputVectors[0] > 0 && initInfo.SigPatchConstOrPrimVectors > 0) {
+          size += PSVComputeInputOutputTableSize(initInfo.SigPatchConstOrPrimVectors, initInfo.SigOutputVectors[0]);
         }
-      }
-      if (initInfo.ShaderStage == PSVShaderKind::Hull && initInfo.SigPatchConstantVectors > 0 && initInfo.SigInputVectors > 0) {
-        size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigPatchConstantVectors);
-      }
-      if (initInfo.ShaderStage == PSVShaderKind::Domain && initInfo.SigOutputVectors[0] > 0 && initInfo.SigPatchConstantVectors > 0) {
-        size += PSVComputeInputOutputTableSize(initInfo.SigPatchConstantVectors, initInfo.SigOutputVectors[0]);
       }
     }
 
@@ -634,11 +649,11 @@ public:
       m_pPSVRuntimeInfo1->UsesViewID = initInfo.UsesViewID;
       m_pPSVRuntimeInfo1->SigInputElements = initInfo.SigInputElements;
       m_pPSVRuntimeInfo1->SigOutputElements = initInfo.SigOutputElements;
-      m_pPSVRuntimeInfo1->SigPatchConstantElements = initInfo.SigPatchConstantElements;
+      m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements = initInfo.SigPatchConstOrPrimElements;
       m_pPSVRuntimeInfo1->SigInputVectors = initInfo.SigInputVectors;
       memcpy(m_pPSVRuntimeInfo1->SigOutputVectors, initInfo.SigOutputVectors, 4);
-      if (IsHS() || IsDS()) {
-        m_pPSVRuntimeInfo1->SigPatchConstantVectors = initInfo.SigPatchConstantVectors;
+      if (IsHS() || IsDS() || IsMS()) {
+        m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors = initInfo.SigPatchConstOrPrimVectors;
       }
 
       // Note: if original size was unaligned, padding has already been zero initialized
@@ -657,7 +672,7 @@ public:
       pCurBits += sizeof(uint32_t) * m_SemanticIndexTable.Entries;
 
       // Dxil Signature Elements
-      if (m_pPSVRuntimeInfo1->SigInputElements || m_pPSVRuntimeInfo1->SigOutputElements || m_pPSVRuntimeInfo1->SigPatchConstantElements) {
+      if (m_pPSVRuntimeInfo1->SigInputElements || m_pPSVRuntimeInfo1->SigOutputElements || m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements) {
         *(uint32_t*)pCurBits = m_uPSVSignatureElementSize;
         pCurBits += sizeof(uint32_t);
       }
@@ -669,9 +684,9 @@ public:
         m_pSigOutputElements = (PSVSignatureElement0*)pCurBits;
         pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigOutputElements;
       }
-      if (m_pPSVRuntimeInfo1->SigPatchConstantElements) {
-        m_pSigPatchConstantElements = (PSVSignatureElement0*)pCurBits;
-        pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigPatchConstantElements;
+      if (m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements) {
+        m_pSigPatchConstOrPrimElements = (PSVSignatureElement0*)pCurBits;
+        pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements;
       }
 
       // ViewID dependencies
@@ -684,9 +699,9 @@ public:
           if (!IsGS())
             break;
         }
-        if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstantVectors) {
-          m_pViewIDPCOutputMask = (uint32_t*)pCurBits;
-          pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+        if ((IsHS() || IsMS()) && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors) {
+          m_pViewIDPCOrPrimOutputMask = (uint32_t*)pCurBits;
+          pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
         }
       }
 
@@ -699,13 +714,13 @@ public:
         if (!IsGS())
           break;
       }
-      if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstantVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
+      if ((IsHS() || IsMS()) && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
         m_pInputToPCOutputTable = (uint32_t*)pCurBits;
-        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
       }
-      if (IsDS() && m_pPSVRuntimeInfo1->SigOutputVectors[0] > 0 && m_pPSVRuntimeInfo1->SigPatchConstantVectors > 0) {
+      if (IsDS() && m_pPSVRuntimeInfo1->SigOutputVectors[0] > 0 && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0) {
         m_pPCInputToOutputTable = (uint32_t*)pCurBits;
-        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstantVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
+        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
       }
     }
 
@@ -747,9 +762,9 @@ public:
       return m_pPSVRuntimeInfo1->SigOutputElements;
     return 0;
   }
-  uint32_t GetSigPatchConstantElements() const {
+  uint32_t GetSigPatchConstOrPrimElements() const {
     if (m_pPSVRuntimeInfo1)
-      return m_pPSVRuntimeInfo1->SigPatchConstantElements;
+      return m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements;
     return 0;
   }
   PSVSignatureElement0* GetInputElement0(uint32_t index) const {
@@ -770,11 +785,11 @@ public:
     }
     return nullptr;
   }
-  PSVSignatureElement0* GetPatchConstantElement0(uint32_t index) const {
-    if (m_pPSVRuntimeInfo1 && m_pSigPatchConstantElements &&
-        index < m_pPSVRuntimeInfo1->SigPatchConstantElements &&
+  PSVSignatureElement0* GetPatchConstOrPrimElement0(uint32_t index) const {
+    if (m_pPSVRuntimeInfo1 && m_pSigPatchConstOrPrimElements &&
+        index < m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements &&
         sizeof(PSVSignatureElement0) <= m_uPSVSignatureElementSize) {
-      return (PSVSignatureElement0*)((uint8_t*)m_pSigPatchConstantElements +
+      return (PSVSignatureElement0*)((uint8_t*)m_pSigPatchConstOrPrimElements +
         (index * m_uPSVSignatureElementSize));
     }
     return nullptr;
@@ -795,6 +810,8 @@ public:
   bool IsGS() const { return GetShaderKind() == PSVShaderKind::Geometry; }
   bool IsPS() const { return GetShaderKind() == PSVShaderKind::Pixel; }
   bool IsCS() const { return GetShaderKind() == PSVShaderKind::Compute; }
+  bool IsMS() const { return GetShaderKind() == PSVShaderKind::Mesh; }
+  bool IsAS() const { return GetShaderKind() == PSVShaderKind::Amplification; }
 
   // ViewID dependencies
   PSVComponentMask GetViewIDOutputMask(unsigned streamIndex = 0) const {
@@ -803,9 +820,9 @@ public:
     return PSVComponentMask(m_pViewIDOutputMask, m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex]);
   }
   PSVComponentMask GetViewIDPCOutputMask() const {
-    if (!IsHS() || !m_pViewIDPCOutputMask || !m_pPSVRuntimeInfo1 || !m_pPSVRuntimeInfo1->SigPatchConstantVectors)
+    if ((!IsHS() && !IsMS()) || !m_pViewIDPCOrPrimOutputMask || !m_pPSVRuntimeInfo1 || !m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors)
       return PSVComponentMask();
-    return PSVComponentMask(m_pViewIDPCOutputMask, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+    return PSVComponentMask(m_pViewIDPCOrPrimOutputMask, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
   }
 
   // Input to Output dependencies
@@ -816,14 +833,14 @@ public:
     return PSVDependencyTable();
   }
   PSVDependencyTable GetInputToPCOutputTable() const {
-    if (IsHS() && m_pInputToPCOutputTable && m_pPSVRuntimeInfo1) {
-      return PSVDependencyTable(m_pInputToPCOutputTable, m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
+    if ((IsHS() || IsMS()) && m_pInputToPCOutputTable && m_pPSVRuntimeInfo1) {
+      return PSVDependencyTable(m_pInputToPCOutputTable, m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
     }
     return PSVDependencyTable();
   }
   PSVDependencyTable GetPCInputToOutputTable() const {
     if (IsDS() && m_pPCInputToOutputTable && m_pPSVRuntimeInfo1) {
-      return PSVDependencyTable(m_pPCInputToOutputTable, m_pPSVRuntimeInfo1->SigPatchConstantVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
+      return PSVDependencyTable(m_pPCInputToOutputTable, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
     }
     return PSVDependencyTable();
   }

+ 3 - 3
include/dxc/HLSL/ComputeViewIdState.h

@@ -55,15 +55,15 @@ struct DxilViewIdStateData {
 
   unsigned m_NumInputSigScalars  = 0;
   unsigned m_NumOutputSigScalars[kNumStreams] = {0,0,0,0};
-  unsigned m_NumPCSigScalars     = 0;
+  unsigned m_NumPCOrPrimSigScalars     = 0;
 
   // Set of scalar outputs dependent on ViewID.
   OutputsDependentOnViewIdType m_OutputsDependentOnViewId[kNumStreams];
-  OutputsDependentOnViewIdType m_PCOutputsDependentOnViewId;
+  OutputsDependentOnViewIdType m_PCOrPrimOutputsDependentOnViewId;
 
   // Set of scalar inputs contributing to computation of scalar outputs.
   InputsContributingToOutputType m_InputsContributingToOutputs[kNumStreams];
-  InputsContributingToOutputType m_InputsContributingToPCOutputs; // HS PC only.
+  InputsContributingToOutputType m_InputsContributingToPCOrPrimOutputs; // HS PC and MS Prim only.
   InputsContributingToOutputType m_PCInputsContributingToOutputs; // DS only.
 
   bool m_bUsesViewId = false;

+ 17 - 0
include/dxc/HLSL/DxilValidation.h

@@ -85,6 +85,9 @@ enum class ValidationRule : unsigned {
   InstrMinPrecisonBitCast, // Bitcast on minprecison types is not allowed
   InstrMipLevelForGetDimension, // Use mip level on buffer when GetDimensions
   InstrMipOnUAVLoad, // uav load don't support mipLevel/sampleIndex
+  InstrMissingSetMeshOutputCounts, // Missing SetMeshOutputCounts call.
+  InstrMultipleGetMeshPayload, // GetMeshPayload cannot be called multiple times.
+  InstrMultipleSetMeshOutputCounts, // SetMeshOUtputCounts cannot be called multiple times.
   InstrNoGenericPtrAddrSpaceCast, // Address space cast between pointer types must have one part to be generic address space
   InstrNoIDivByZero, // No signed integer division by zero
   InstrNoIndefiniteAcos, // No indefinite arccosine
@@ -93,6 +96,9 @@ enum class ValidationRule : unsigned {
   InstrNoIndefiniteLog, // No indefinite logarithm
   InstrNoReadingUninitialized, // Instructions should not read uninitialized value
   InstrNoUDivByZero, // No unsigned integer division by zero
+  InstrNonDominatingDispatchMesh, // Non-Dominating DispatchMesh call.
+  InstrNonDominatingSetMeshOutputCounts, // Non-Dominating SetMeshOutputCounts call.
+  InstrNotOnceDispatchMesh, // DispatchMesh must be called exactly once in an Amplification shader.
   InstrOffsetOnUAVLoad, // uav load don't support offset
   InstrOload, // DXIL intrinsic overload must be valid
   InstrOnlyOneAllocConsume, // RWStructuredBuffers may increment or decrement their counters, but not both.
@@ -190,6 +196,7 @@ enum class ValidationRule : unsigned {
 
   // Shader model
   Sm64bitRawBufferLoadStore, // i64/f64 rawBufferLoad/Store overloads are allowed after SM 6.3
+  SmAmplificationShaderPayloadSize, // For shader '%0', payload size is greater than %1
   SmAppendAndConsumeOnSameUAV, // BufferUpdateCounter inc and dec on a given UAV (%d) cannot both be in the same shader for shader model less than 5.1.
   SmCBufferArrayOffsetAlignment, // CBuffer array offset must be aligned to 16-bytes
   SmCBufferElementOverflow, // CBuffer elements must not overflow
@@ -197,6 +204,7 @@ enum class ValidationRule : unsigned {
   SmCBufferTemplateTypeMustBeStruct, // D3D12 constant/texture buffer template element can only be a struct
   SmCSNoSignatures, // Compute shaders must not have shader signatures.
   SmCompletePosition, // Not all elements of SV_Position were written
+  SmConstantInterpMode, // Interpolation mode must be constant for MS primitive output.
   SmCounterOnlyOnStructBuf, // BufferUpdateCounter valid only on structured buffers
   SmDSInputControlPointCountRange, // DS input control point count must be [0..%0].  %1 specified
   SmDomainLocationIdxOOB, // DomainLocation component index out of bounds for the domain.
@@ -213,8 +221,17 @@ enum class ValidationRule : unsigned {
   SmInvalidResourceKind, // Invalid resources kind
   SmInvalidTextureKindOnUAV, // Texture2DMS[Array] or TextureCube[Array] resources are not supported with UAVs
   SmIsoLineOutputPrimitiveMismatch, // Hull Shader declared with IsoLine Domain must specify output primitive point or line. Triangle_cw or triangle_ccw output are not compatible with the IsoLine Domain.
+  SmMaxMSSMSize, // Total Thread Group Shared Memory storage is %0, exceeded %1
   SmMaxTGSMSize, // Total Thread Group Shared Memory storage is %0, exceeded %1
   SmMaxTheadGroup, // Declared Thread Group Count %0 (X*Y*Z) is beyond the valid maximum of %1
+  SmMeshPSigRowCount, // For shader '%0', primitive output signatures are taking up more than %1 rows
+  SmMeshShaderInOutSize, // For shader '%0', input plus output size is greater than %1
+  SmMeshShaderMaxPrimitiveCount, // MS max primitive output count must be [0..%0].  %1 specified
+  SmMeshShaderMaxVertexCount, // MS max vertex output count must be [0..%0].  %1 specified
+  SmMeshShaderOutputSize, // For shader '%0', vertex plus primitive output size is greater than %1
+  SmMeshShaderPayloadSize, // For shader '%0', payload size is greater than %1
+  SmMeshTotalSigRowCount, // For shader '%0', vertex and primitive output signatures are taking up more than %1 rows
+  SmMeshVSigRowCount, // For shader '%0', vertex output signatures are taking up more than %1 rows
   SmMultiStreamMustBePoint, // When multiple GS output streams are used they must be pointlists
   SmName, // Target shader model name must be known
   SmNoInterpMode, // Interpolation mode must be undefined for VS input/PS output/patch constant.

+ 6 - 0
include/dxc/HLSL/HLOperations.h

@@ -336,6 +336,12 @@ const unsigned kTraceRayPayLoadOpIdx = 8;
 // ReportIntersection.
 const unsigned kReportIntersectionAttributeOpIdx = 3;
 
+// DispatchMesh
+const unsigned kDispatchMeshOpThreadX = 1;
+const unsigned kDispatchMeshOpThreadY = 2;
+const unsigned kDispatchMeshOpThreadZ = 3;
+const unsigned kDispatchMeshOpPayload = 4;
+
 } // namespace HLOperandIndex
 
 llvm::Function *GetOrCreateHLFunction(llvm::Module &M,

+ 4 - 4
include/dxc/HLSL/ViewIDPipelineValidation.inl

@@ -306,9 +306,9 @@ public:
         [&](unsigned i) -> PSVSignatureElement {
         return PSV.GetSignatureElement(PSV.GetOutputElement0(i));
       });
-      CopyElements(pcSig, DXIL::SigPointKind::PCOut, PSV.GetSigPatchConstantElements(), 0,
+      CopyElements(pcSig, DXIL::SigPointKind::PCOut, PSV.GetSigPatchConstOrPrimElements(), 0,
         [&](unsigned i) -> PSVSignatureElement {
-        return PSV.GetSignatureElement(PSV.GetPatchConstantElement0(i));
+        return PSV.GetSignatureElement(PSV.GetPatchConstOrPrimElement0(i));
       });
 
       // Propagate prior mask through input-output dependencies
@@ -345,9 +345,9 @@ public:
                     [&](unsigned i) -> PSVSignatureElement {
                       return PSV.GetSignatureElement(PSV.GetInputElement0(i));
                     });
-      CopyElements( pcSig, DXIL::SigPointKind::DSIn, PSV.GetSigPatchConstantElements(), 0,
+      CopyElements( pcSig, DXIL::SigPointKind::DSIn, PSV.GetSigPatchConstOrPrimElements(), 0,
                     [&](unsigned i) -> PSVSignatureElement {
-                      return PSV.GetSignatureElement(PSV.GetPatchConstantElement0(i));
+                      return PSV.GetSignatureElement(PSV.GetPatchConstOrPrimElement0(i));
                     });
 
       // Merge prior and input signatures, update prior mask size if necessary

+ 2 - 0
include/dxc/HlslIntrinsicOp.h

@@ -30,6 +30,7 @@ import hctdb_instrhelp
   IOP_D3DCOLORtoUBYTE4,
   IOP_DeviceMemoryBarrier,
   IOP_DeviceMemoryBarrierWithGroupSync,
+  IOP_DispatchMesh,
   IOP_DispatchRaysDimensions,
   IOP_DispatchRaysIndex,
   IOP_EvaluateAttributeAtSample,
@@ -78,6 +79,7 @@ import hctdb_instrhelp
   IOP_RayTCurrent,
   IOP_RayTMin,
   IOP_ReportHit,
+  IOP_SetMeshOutputCounts,
   IOP_TraceRay,
   IOP_WaveActiveAllEqual,
   IOP_WaveActiveAllTrue,

+ 1 - 1
include/dxc/Support/HLSLOptions.td

@@ -291,7 +291,7 @@ def Oconfig : CommaJoined<["-"], "Oconfig=">, Group<spirv_Group>, Flags<[CoreOpt
 // fxc-based flags that don't match those previously defined.
 
 def target_profile : JoinedOrSeparate<["-", "/"], "T">, Flags<[CoreOption]>, Group<hlslcomp_Group>, MetaVarName<"<profile>">,
-  HelpText<"Set target profile. \n\t<profile>: ps_6_0, ps_6_1, ps_6_2, ps_6_3, ps_6_4, ps_6_5, \n\t\t vs_6_0, vs_6_1, vs_6_2, vs_6_3, vs_6_4, vs_6_5, \n\t\t cs_6_0, cs_6_1, cs_6_2, cs_6_3, cs_6_4, cs_6_5, \n\t\t gs_6_0, gs_6_1, gs_6_2, gs_6_3, gs_6_4, gs_6_5, \n\t\t ds_6_0, ds_6_1, ds_6_2, ds_6_3, ds_6_4, ds_6_5, \n\t\t hs_6_0, hs_6_1, hs_6_2, hs_6_3, hs_6_4, hs_6_5, \n\t\t lib_6_3, lib_6_4, lib_6_5">;
+  HelpText<"Set target profile. \n\t<profile>: ps_6_0, ps_6_1, ps_6_2, ps_6_3, ps_6_4, ps_6_5, \n\t\t vs_6_0, vs_6_1, vs_6_2, vs_6_3, vs_6_4, vs_6_5, \n\t\t cs_6_0, cs_6_1, cs_6_2, cs_6_3, cs_6_4, cs_6_5, \n\t\t gs_6_0, gs_6_1, gs_6_2, gs_6_3, gs_6_4, gs_6_5, \n\t\t ds_6_0, ds_6_1, ds_6_2, ds_6_3, ds_6_4, ds_6_5, \n\t\t hs_6_0, hs_6_1, hs_6_2, hs_6_3, hs_6_4, hs_6_5, \n\t\t lib_6_3, lib_6_4, lib_6_5, ms_6_5, as_6_5">;
 def entrypoint :  JoinedOrSeparate<["-", "/"], "E">, Flags<[CoreOption]>, Group<hlslcomp_Group>,
   HelpText<"Entry point name">;
 // /I <include> - already defined above

+ 1 - 0
include/dxc/Support/SPIRVOptions.h

@@ -56,6 +56,7 @@ struct SpirvCodeGenOptions {
   SpirvLayoutRule cBufferLayoutRule;
   SpirvLayoutRule sBufferLayoutRule;
   SpirvLayoutRule tBufferLayoutRule;
+  SpirvLayoutRule ampPayloadLayoutRule;
   llvm::StringRef stageIoOrder;
   llvm::StringRef targetEnv;
   llvm::SmallVector<int32_t, 4> bShift;

+ 128 - 5
lib/DXIL/DxilMetadataHelper.cpp

@@ -318,13 +318,13 @@ MDTuple *DxilMDHelper::EmitDxilSignatures(const DxilEntrySignature &EntrySig) {
 
   const DxilSignature &InputSig = EntrySig.InputSignature;
   const DxilSignature &OutputSig = EntrySig.OutputSignature;
-  const DxilSignature &PCSig = EntrySig.PatchConstantSignature;
+  const DxilSignature &PCPSig = EntrySig.PatchConstOrPrimSignature;
 
-  if (!InputSig.GetElements().empty() || !OutputSig.GetElements().empty() || !PCSig.GetElements().empty()) {
+  if (!InputSig.GetElements().empty() || !OutputSig.GetElements().empty() || !PCPSig.GetElements().empty()) {
     Metadata *MDVals[kDxilNumSignatureFields];
     MDVals[kDxilInputSignature]         = EmitSignatureMetadata(InputSig);
     MDVals[kDxilOutputSignature]        = EmitSignatureMetadata(OutputSig);
-    MDVals[kDxilPatchConstantSignature] = EmitSignatureMetadata(PCSig);
+    MDVals[kDxilPatchConstantSignature] = EmitSignatureMetadata(PCPSig);
 
     pSignatureTupleMD = MDNode::get(m_Ctx, MDVals);
   }
@@ -354,14 +354,14 @@ void DxilMDHelper::LoadDxilSignatures(const MDOperand &MDO, DxilEntrySignature &
     return;
   DxilSignature &InputSig = EntrySig.InputSignature;
   DxilSignature &OutputSig = EntrySig.OutputSignature;
-  DxilSignature &PCSig = EntrySig.PatchConstantSignature;
+  DxilSignature &PCPSig = EntrySig.PatchConstOrPrimSignature;
   const MDTuple *pTupleMD = dyn_cast<MDTuple>(MDO.get());
   IFTBOOL(pTupleMD != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
   IFTBOOL(pTupleMD->getNumOperands() == kDxilNumSignatureFields, DXC_E_INCORRECT_DXIL_METADATA);
 
   LoadSignatureMetadata(pTupleMD->getOperand(kDxilInputSignature),         InputSig);
   LoadSignatureMetadata(pTupleMD->getOperand(kDxilOutputSignature),        OutputSig);
-  LoadSignatureMetadata(pTupleMD->getOperand(kDxilPatchConstantSignature), PCSig);
+  LoadSignatureMetadata(pTupleMD->getOperand(kDxilPatchConstantSignature), PCPSig);
 }
 
 MDTuple *DxilMDHelper::EmitSignatureMetadata(const DxilSignature &Sig) {
@@ -1024,6 +1024,28 @@ const Function *DxilMDHelper::LoadDxilFunctionProps(const MDTuple *pProps,
       props->ShaderProps.Ray.attributeSizeInBytes =
         ConstMDToUint32(pProps->getOperand(idx++));
     break;
+  case DXIL::ShaderKind::Mesh:
+    props->ShaderProps.MS.numThreads[0] =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.MS.numThreads[1] =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.MS.numThreads[2] =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.MS.maxVertexCount =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.MS.maxPrimitiveCount =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.MS.outputTopology =
+      (DXIL::MeshOutputTopology)ConstMDToUint32(pProps->getOperand(idx++));
+    break;
+  case DXIL::ShaderKind::Amplification:
+    props->ShaderProps.AS.numThreads[0] =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.AS.numThreads[1] =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    props->ShaderProps.AS.numThreads[2] =
+      ConstMDToUint32(pProps->getOperand(idx++));
+    break;
   default:
     break;
   }
@@ -1123,6 +1145,21 @@ MDTuple *DxilMDHelper::EmitDxilEntryProperties(uint64_t rawShaderFlag,
     MDVals.emplace_back(
         Uint32ToConstMD(props.ShaderProps.Ray.payloadSizeInBytes));
   } break;
+  case DXIL::ShaderKind::Mesh: {
+    auto &MS = props.ShaderProps.MS;
+    MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilMSStateTag));
+    MDTuple *pMDTuple = EmitDxilMSState(MS.numThreads,
+                                        MS.maxVertexCount,
+                                        MS.maxPrimitiveCount,
+                                        MS.outputTopology);
+    MDVals.emplace_back(pMDTuple);
+  } break;
+  case DXIL::ShaderKind::Amplification: {
+    auto &AS = props.ShaderProps.AS;
+    MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilASStateTag));
+    MDTuple *pMDTuple = EmitDxilASState(AS.numThreads);
+    MDVals.emplace_back(pMDTuple);
+  } break;
   default:
     break;
   }
@@ -1239,6 +1276,17 @@ void DxilMDHelper::LoadDxilEntryProperties(const MDOperand &MDO,
                "else invalid shader kind");
       props.shaderKind = kind;
     } break;
+    case DxilMDHelper::kDxilMSStateTag: {
+      DXASSERT(props.IsMS(), "else invalid shader kind");
+      auto &MS = props.ShaderProps.MS;
+      LoadDxilMSState(MDO, MS.numThreads, MS.maxVertexCount,
+                      MS.maxPrimitiveCount, MS.outputTopology);
+    } break;
+    case DxilMDHelper::kDxilASStateTag: {
+      DXASSERT(props.IsAS(), "else invalid shader kind");
+      auto &AS = props.ShaderProps.AS;
+      LoadDxilASState(MDO, AS.numThreads);
+    } break;
     default:
       DXASSERT(false, "Unknown extended shader properties tag");
       break;
@@ -1307,6 +1355,20 @@ DxilMDHelper::EmitDxilFunctionProps(const hlsl::DxilFunctionProps *props,
     if (bRayAttributes)
       MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.Ray.attributeSizeInBytes);
     break;
+  case DXIL::ShaderKind::Mesh:
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.MS.numThreads[0]);
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.MS.numThreads[1]);
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.MS.numThreads[2]);
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.MS.maxVertexCount);
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.MS.maxPrimitiveCount);
+    MDVals[valIdx++] =
+        Uint8ToConstMD((uint8_t)props->ShaderProps.MS.outputTopology);
+    break;
+  case DXIL::ShaderKind::Amplification:
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.AS.numThreads[0]);
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.AS.numThreads[1]);
+    MDVals[valIdx++] = Uint32ToConstMD(props->ShaderProps.AS.numThreads[2]);
+    break;
   default:
     break;
   }
@@ -1768,6 +1830,67 @@ void DxilMDHelper::LoadDxilHSState(const MDOperand &MDO,
   MaxTessFactor           = ConstMDToFloat(pTupleMD->getOperand(kDxilHSStateMaxTessellationFactor));
 }
 
+MDTuple *DxilMDHelper::EmitDxilMSState(const unsigned *NumThreads,
+                                       unsigned MaxVertexCount,
+                                       unsigned MaxPrimitiveCount,
+                                       DXIL::MeshOutputTopology OutputTopology) {
+  Metadata *MDVals[kDxilMSStateNumFields];
+  vector<Metadata *> NumThreadVals;
+
+  NumThreadVals.emplace_back(Uint32ToConstMD(NumThreads[0]));
+  NumThreadVals.emplace_back(Uint32ToConstMD(NumThreads[1]));
+  NumThreadVals.emplace_back(Uint32ToConstMD(NumThreads[2]));
+  MDVals[kDxilMSStateNumThreads] = MDNode::get(m_Ctx, NumThreadVals);
+  MDVals[kDxilMSStateMaxVertexCount] = Uint32ToConstMD(MaxVertexCount);
+  MDVals[kDxilMSStateMaxPrimitiveCount] = Uint32ToConstMD(MaxPrimitiveCount);
+  MDVals[kDxilMSStateOutputTopology] = Uint32ToConstMD((unsigned)OutputTopology);
+
+  return MDNode::get(m_Ctx, MDVals);
+}
+
+void DxilMDHelper::LoadDxilMSState(const MDOperand &MDO,
+                                   unsigned *NumThreads,
+                                   unsigned &MaxVertexCount,
+                                   unsigned &MaxPrimitiveCount,
+                                   DXIL::MeshOutputTopology &OutputTopology) {
+  IFTBOOL(MDO.get() != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  const MDTuple *pTupleMD = dyn_cast<MDTuple>(MDO.get());
+  IFTBOOL(pTupleMD != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  IFTBOOL(pTupleMD->getNumOperands() == kDxilMSStateNumFields, DXC_E_INCORRECT_DXIL_METADATA);
+
+  MDNode *pNode = cast<MDNode>(pTupleMD->getOperand(kDxilMSStateNumThreads));
+  NumThreads[0] = ConstMDToUint32(pNode->getOperand(0));
+  NumThreads[1] = ConstMDToUint32(pNode->getOperand(1));
+  NumThreads[2] = ConstMDToUint32(pNode->getOperand(2));
+  MaxVertexCount = ConstMDToUint32(pTupleMD->getOperand(kDxilMSStateMaxVertexCount));
+  MaxPrimitiveCount = ConstMDToUint32(pTupleMD->getOperand(kDxilMSStateMaxPrimitiveCount));
+  OutputTopology = (DXIL::MeshOutputTopology)ConstMDToUint32(pTupleMD->getOperand(kDxilMSStateOutputTopology));
+}
+
+MDTuple *DxilMDHelper::EmitDxilASState(const unsigned *NumThreads) {
+  Metadata *MDVals[kDxilASStateNumFields];
+  vector<Metadata *> NumThreadVals;
+
+  NumThreadVals.emplace_back(Uint32ToConstMD(NumThreads[0]));
+  NumThreadVals.emplace_back(Uint32ToConstMD(NumThreads[1]));
+  NumThreadVals.emplace_back(Uint32ToConstMD(NumThreads[2]));
+  MDVals[kDxilASStateNumThreads] = MDNode::get(m_Ctx, NumThreadVals);
+
+  return MDNode::get(m_Ctx, MDVals);
+}
+
+void DxilMDHelper::LoadDxilASState(const MDOperand &MDO, unsigned *NumThreads) {
+  IFTBOOL(MDO.get() != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  const MDTuple *pTupleMD = dyn_cast<MDTuple>(MDO.get());
+  IFTBOOL(pTupleMD != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  IFTBOOL(pTupleMD->getNumOperands() == kDxilASStateNumFields, DXC_E_INCORRECT_DXIL_METADATA);
+
+  MDNode *pNode = cast<MDNode>(pTupleMD->getOperand(kDxilASStateNumThreads));
+  NumThreads[0] = ConstMDToUint32(pNode->getOperand(0));
+  NumThreads[1] = ConstMDToUint32(pNode->getOperand(1));
+  NumThreads[2] = ConstMDToUint32(pNode->getOperand(2));
+}
+
 //
 // DxilExtraPropertyHelper methods.
 //

+ 74 - 4
lib/DXIL/DxilModule.cpp

@@ -627,6 +627,74 @@ void DxilModule::SetMaxTessellationFactor(float MaxTessellationFactor) {
   props.ShaderProps.HS.maxTessFactor = MaxTessellationFactor;
 }
 
+unsigned DxilModule::GetMaxOutputVertices() const {
+  if (!m_pSM->IsMS())
+    return 0;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsMS(), "Must be MS profile");
+  return props.ShaderProps.MS.maxVertexCount;
+}
+
+void DxilModule::SetMaxOutputVertices(unsigned NumOVs) {
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsMS(),
+           "only works for MS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsMS(), "Must be MS profile");
+  props.ShaderProps.MS.maxVertexCount = NumOVs;
+}
+
+unsigned DxilModule::GetMaxOutputPrimitives() const {
+  if (!m_pSM->IsMS())
+    return 0;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsMS(), "Must be MS profile");
+  return props.ShaderProps.MS.maxPrimitiveCount;
+}
+
+void DxilModule::SetMaxOutputPrimitives(unsigned NumOPs) {
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsMS(),
+           "only works for MS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsMS(), "Must be MS profile");
+  props.ShaderProps.MS.maxPrimitiveCount = NumOPs;
+}
+
+DXIL::MeshOutputTopology DxilModule::GetMeshOutputTopology() const {
+  if (!m_pSM->IsMS())
+    return DXIL::MeshOutputTopology::Undefined;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsMS(), "Must be MS profile");
+  return props.ShaderProps.MS.outputTopology;
+}
+
+void DxilModule::SetMeshOutputTopology(DXIL::MeshOutputTopology MeshOutputTopology) {
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsMS(),
+           "only works for MS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsMS(), "Must be MS profile");
+  props.ShaderProps.MS.outputTopology = MeshOutputTopology;
+}
+
+unsigned DxilModule::GetPayloadByteSize() const {
+  if (!m_pSM->IsMS())
+    return 0;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsMS(), "Must be MS profile");
+  return props.ShaderProps.MS.payloadByteSize;
+}
+
+void DxilModule::SetPayloadByteSize(unsigned Size) {
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsMS(),
+           "only works for MS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsMS(), "Must be MS profile");
+  props.ShaderProps.MS.payloadByteSize = Size;
+}
+
 void DxilModule::SetAutoBindingSpace(uint32_t Space) {
   m_AutoBindingSpace = Space;
 }
@@ -651,6 +719,8 @@ void DxilModule::SetShaderProperties(DxilFunctionProps *props) {
   case DXIL::ShaderKind::Domain:
   case DXIL::ShaderKind::Hull:
   case DXIL::ShaderKind::Vertex:
+  case DXIL::ShaderKind::Mesh:
+  case DXIL::ShaderKind::Amplification:
     break;
   default: {
     DXASSERT(props->shaderKind == DXIL::ShaderKind::Geometry,
@@ -929,16 +999,16 @@ const DxilSignature &DxilModule::GetOutputSignature() const {
   return m_DxilEntryPropsMap.begin()->second->sig.OutputSignature;
 }
 
-DxilSignature &DxilModule::GetPatchConstantSignature() {
+DxilSignature &DxilModule::GetPatchConstOrPrimSignature() {
   DXASSERT(m_DxilEntryPropsMap.size() == 1 && !m_pSM->IsLib(),
            "only works for non-lib profile");
-  return m_DxilEntryPropsMap.begin()->second->sig.PatchConstantSignature;
+  return m_DxilEntryPropsMap.begin()->second->sig.PatchConstOrPrimSignature;
 }
 
-const DxilSignature &DxilModule::GetPatchConstantSignature() const {
+const DxilSignature &DxilModule::GetPatchConstOrPrimSignature() const {
   DXASSERT(m_DxilEntryPropsMap.size() == 1 && !m_pSM->IsLib(),
            "only works for non-lib profile");
-  return m_DxilEntryPropsMap.begin()->second->sig.PatchConstantSignature;
+  return m_DxilEntryPropsMap.begin()->second->sig.PatchConstOrPrimSignature;
 }
 
 const std::vector<uint8_t> &DxilModule::GetSerializedRootSignature() const {

+ 42 - 4
lib/DXIL/DxilOperations.cpp

@@ -187,7 +187,7 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = {
   {  OC::Coverage,                "Coverage",                 OCC::Coverage,                 "coverage",                  { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadNone, },
   {  OC::InnerCoverage,           "InnerCoverage",            OCC::InnerCoverage,            "innerCoverage",             { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadNone, },
 
-  // Compute shader                                                                                                          void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  // Compute/Mesh/Amplification shader                                                                                       void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
   {  OC::ThreadId,                "ThreadId",                 OCC::ThreadId,                 "threadId",                  { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadNone, },
   {  OC::GroupId,                 "GroupId",                  OCC::GroupId,                  "groupId",                   { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadNone, },
   {  OC::ThreadIdInGroup,         "ThreadIdInGroup",          OCC::ThreadIdInGroup,          "threadIdInGroup",           { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadNone, },
@@ -321,6 +321,16 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = {
   {  OC::WaveMatch,               "WaveMatch",                OCC::WaveMatch,                "waveMatch",                 { false,  true,  true,  true, false,  true,  true,  true,  true, false, false}, Attribute::None,     },
   {  OC::WaveMultiPrefixOp,       "WaveMultiPrefixOp",        OCC::WaveMultiPrefixOp,        "waveMultiPrefixOp",         { false,  true,  true,  true, false,  true,  true,  true,  true, false, false}, Attribute::None,     },
   {  OC::WaveMultiPrefixBitCount, "WaveMultiPrefixBitCount",  OCC::WaveMultiPrefixBitCount,  "waveMultiPrefixBitCount",   {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+
+  // Mesh shader instructions                                                                                                void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  {  OC::SetMeshOutputCounts,     "SetMeshOutputCounts",      OCC::SetMeshOutputCounts,      "setMeshOutputCounts",       {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+  {  OC::EmitIndices,             "EmitIndices",              OCC::EmitIndices,              "emitIndices",               {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+  {  OC::GetMeshPayload,          "GetMeshPayload",           OCC::GetMeshPayload,           "getMeshPayload",            { false, false, false, false, false, false, false, false, false,  true, false}, Attribute::ReadOnly, },
+  {  OC::StoreVertexOutput,       "StoreVertexOutput",        OCC::StoreVertexOutput,        "storeVertexOutput",         { false,  true,  true, false, false, false,  true,  true, false, false, false}, Attribute::None,     },
+  {  OC::StorePrimitiveOutput,    "StorePrimitiveOutput",     OCC::StorePrimitiveOutput,     "storePrimitiveOutput",      { false,  true,  true, false, false, false,  true,  true, false, false, false}, Attribute::None,     },
+
+  // Amplification shader instructions                                                                                       void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  {  OC::DispatchMesh,            "DispatchMesh",             OCC::DispatchMesh,             "dispatchMesh",              { false, false, false, false, false, false, false, false, false,  true, false}, Attribute::None,     },
 };
 // OPCODE-OLOADS:END
 
@@ -533,7 +543,7 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
   // Instructions: ThreadId=93, GroupId=94, ThreadIdInGroup=95,
   // FlattenedThreadIdInGroup=96
   if ((93 <= op && op <= 96)) {
-    mask = SFLAG(Compute);
+    mask = SFLAG(Compute) | SFLAG(Mesh) | SFLAG(Amplification);
     return;
   }
   // Instructions: DomainLocation=105
@@ -585,7 +595,7 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
   // Instructions: ViewID=138
   if (op == 138) {
     major = 6;  minor = 1;
-    mask = SFLAG(Vertex) | SFLAG(Hull) | SFLAG(Domain) | SFLAG(Geometry) | SFLAG(Pixel);
+    mask = SFLAG(Vertex) | SFLAG(Hull) | SFLAG(Domain) | SFLAG(Geometry) | SFLAG(Pixel) | SFLAG(Mesh);
     return;
   }
   // Instructions: RawBufferLoad=139, RawBufferStore=140
@@ -662,6 +672,19 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
     major = 6;  minor = 5;
     return;
   }
+  // Instructions: DispatchMesh=173
+  if (op == 173) {
+    major = 6;  minor = 5;
+    mask = SFLAG(Amplification);
+    return;
+  }
+  // Instructions: SetMeshOutputCounts=168, EmitIndices=169, GetMeshPayload=170,
+  // StoreVertexOutput=171, StorePrimitiveOutput=172
+  if ((168 <= op && op <= 172)) {
+    major = 6;  minor = 5;
+    mask = SFLAG(Mesh);
+    return;
+  }
   // OPCODE-SMMASK:END
 #undef SFLAG
 }
@@ -928,7 +951,7 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
   case OpCode::Coverage:               A(pI32);     A(pI32); break;
   case OpCode::InnerCoverage:          A(pI32);     A(pI32); break;
 
-    // Compute shader
+    // Compute/Mesh/Amplification shader
   case OpCode::ThreadId:               A(pI32);     A(pI32); A(pI32); break;
   case OpCode::GroupId:                A(pI32);     A(pI32); A(pI32); break;
   case OpCode::ThreadIdInGroup:        A(pI32);     A(pI32); A(pI32); break;
@@ -1062,6 +1085,16 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
   case OpCode::WaveMatch:              A(pI4S);     A(pI32); A(pETy); break;
   case OpCode::WaveMultiPrefixOp:      A(pETy);     A(pI32); A(pETy); A(pI32); A(pI32); A(pI32); A(pI32); A(pI8);  A(pI8);  break;
   case OpCode::WaveMultiPrefixBitCount:A(pI32);     A(pI32); A(pI1);  A(pI32); A(pI32); A(pI32); A(pI32); break;
+
+    // Mesh shader instructions
+  case OpCode::SetMeshOutputCounts:    A(pV);       A(pI32); A(pI32); A(pI32); break;
+  case OpCode::EmitIndices:            A(pV);       A(pI32); A(pI32); A(pI32); A(pI32); A(pI32); break;
+  case OpCode::GetMeshPayload:         A(pETy);     A(pI32); break;
+  case OpCode::StoreVertexOutput:      A(pV);       A(pI32); A(pI32); A(pI32); A(pI8);  A(pETy); A(pI32); break;
+  case OpCode::StorePrimitiveOutput:   A(pV);       A(pI32); A(pI32); A(pI32); A(pI8);  A(pETy); A(pI32); break;
+
+    // Amplification shader instructions
+  case OpCode::DispatchMesh:           A(pV);       A(pI32); A(pI32); A(pI32); A(pI32); A(pETy); break;
   // OPCODE-OLOAD-FUNCS:END
   default: DXASSERT(false, "otherwise unhandled case"); break;
   }
@@ -1152,6 +1185,9 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   case OpCode::BufferStore:
   case OpCode::StorePatchConstant:
   case OpCode::RawBufferStore:
+  case OpCode::StoreVertexOutput:
+  case OpCode::StorePrimitiveOutput:
+  case OpCode::DispatchMesh:
     DXASSERT_NOMSG(FT->getNumParams() > 4);
     return FT->getParamType(4);
   case OpCode::IsNaN:
@@ -1215,6 +1251,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   case OpCode::IgnoreHit:
   case OpCode::AcceptHitAndEndSearch:
   case OpCode::WaveMultiPrefixBitCount:
+  case OpCode::SetMeshOutputCounts:
+  case OpCode::EmitIndices:
     return Type::getVoidTy(m_Ctx);
   case OpCode::CheckAccessFullyMapped:
   case OpCode::AtomicBinOp:

+ 1 - 0
lib/DXIL/DxilSemantic.cpp

@@ -146,6 +146,7 @@ const Semantic Semantic::ms_SemanticTable[kNumSemanticRecords] = {
   SP(Kind::ViewID,                "SV_ViewID"),
   SP(Kind::Barycentrics,          "SV_Barycentrics"),
   SP(Kind::ShadingRate,           "SV_ShadingRate"),
+  SP(Kind::CullPrimitive,         "SV_CullPrimitive"),
   SP(Kind::Invalid,               nullptr),
 };
 

+ 8 - 3
lib/DXIL/DxilShaderModel.cpp

@@ -43,7 +43,7 @@ bool ShaderModel::operator==(const ShaderModel &other) const {
 
 bool ShaderModel::IsValid() const {
   DXASSERT(IsPS() || IsVS() || IsGS() || IsHS() || IsDS() || IsCS() ||
-               IsLib() || m_Kind == Kind::Invalid,
+               IsLib() || IsMS() || IsAS() || m_Kind == Kind::Invalid,
            "invalid shader model");
   return m_Kind != Kind::Invalid;
 }
@@ -103,7 +103,7 @@ const ShaderModel *ShaderModel::Get(Kind Kind, unsigned Major, unsigned Minor) {
 }
 
 const ShaderModel *ShaderModel::GetByName(const char *pszName) {
-  // [ps|vs|gs|hs|ds|cs]_[major]_[minor]
+  // [ps|vs|gs|hs|ds|cs|ms|as]_[major]_[minor]
   Kind kind;
   switch (pszName[0]) {
   case 'p':   kind = Kind::Pixel;     break;
@@ -113,6 +113,8 @@ const ShaderModel *ShaderModel::GetByName(const char *pszName) {
   case 'd':   kind = Kind::Domain;    break;
   case 'c':   kind = Kind::Compute;   break;
   case 'l':   kind = Kind::Library;   break;
+  case 'm':   kind = Kind::Mesh;      break;
+  case 'a':   kind = Kind::Amplification; break;
   default:    return GetInvalid();
   }
   unsigned Idx = 3;
@@ -241,7 +243,7 @@ void ShaderModel::GetMinValidatorVersion(unsigned &ValMajor, unsigned &ValMinor)
 static const char *ShaderModelKindNames[] = {
     "ps", "vs", "gs", "hs", "ds", "cs", "lib",
     "raygeneration", "intersection", "anyhit", "closesthit", "miss", "callable",
-    "invalid",
+    "ms", "as", "invalid",
 };
 
 const char * ShaderModel::GetKindName() const {
@@ -331,6 +333,9 @@ const ShaderModel ShaderModel::ms_ShaderModels[kNumShaderModels] = {
   // lib_6_x is for offline linking only, and relaxes restrictions
   SM(Kind::Library,  6, kOfflineMinor, "lib_6_x",  32, 32,  true,  true,  UINT_MAX),
 
+  SM(Kind::Mesh,     6, 5, "ms_6_5",    0,  0,  true,  true,  UINT_MAX),
+  SM(Kind::Amplification, 6, 5, "as_6_5", 0, 0, true,  true,  UINT_MAX),
+
   // Values before Invalid must remain sorted by Kind, then Major, then Minor.
 
   SM(Kind::Invalid,  0, 0, "invalid", 0,  0,   false, false, 0),

+ 12 - 1
lib/DXIL/DxilSignature.cpp

@@ -118,13 +118,24 @@ unsigned DxilSignature::NumVectorsUsed(unsigned streamIndex) const {
   return NumVectors;
 }
 
+unsigned DxilSignature::GetRowCount() const {
+  unsigned maxRow = 0;
+  for (auto &E : GetElements()) {
+    unsigned endRow = E->GetStartRow() + E->GetRows();
+    if (maxRow < endRow) {
+      maxRow = endRow;
+    }
+  }
+  return maxRow;
+}
+
 //------------------------------------------------------------------------------
 //
 // EntrySingnature methods.
 //
 DxilEntrySignature::DxilEntrySignature(const DxilEntrySignature &src)
     : InputSignature(src.InputSignature), OutputSignature(src.OutputSignature),
-      PatchConstantSignature(src.PatchConstantSignature) {}
+      PatchConstOrPrimSignature(src.PatchConstOrPrimSignature) {}
 
 } // namespace hlsl
 

+ 2 - 2
lib/DXIL/DxilSignatureElement.cpp

@@ -88,8 +88,8 @@ bool DxilSignatureElement::IsOutput() const {
   return SigPoint::GetSigPoint(m_sigPointKind)->IsOutput();
 }
 
-bool DxilSignatureElement::IsPatchConstant() const {
-  return SigPoint::GetSigPoint(m_sigPointKind)->IsPatchConstant();
+bool DxilSignatureElement::IsPatchConstOrPrim() const {
+  return SigPoint::GetSigPoint(m_sigPointKind)->IsPatchConstOrPrim();
 }
 
 const char *DxilSignatureElement::GetName() const {

+ 22 - 0
lib/DXIL/DxilTypeSystem.cpp

@@ -410,6 +410,28 @@ DXIL::SigPointKind SigPointFromInputQual(DxilParamInputQual Q, DXIL::ShaderKind
       break;
     }
     break;
+  case DXIL::ShaderKind::Mesh:
+    switch (Q) {
+    case DxilParamInputQual::In:
+    case DxilParamInputQual::InPayload:
+      return DXIL::SigPointKind::MSIn;
+    case DxilParamInputQual::OutIndices:
+    case DxilParamInputQual::OutVertices:
+      return DXIL::SigPointKind::MSOut;
+    case DxilParamInputQual::OutPrimitives:
+      return DXIL::SigPointKind::MSPOut;
+    default:
+      break;
+    }
+    break;
+  case DXIL::ShaderKind::Amplification:
+    switch (Q) {
+    case DxilParamInputQual::In:
+      return DXIL::SigPointKind::ASIn;
+    default:
+      break;
+    }
+    break;
   default:
     break;
   }

+ 71 - 25
lib/DxilContainer/DxilContainerAssembler.cpp

@@ -13,6 +13,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/Bitcode/ReaderWriter.h"
 #include "llvm/Support/MD5.h"
 #include "llvm/ADT/STLExtras.h"
@@ -312,9 +313,9 @@ DxilPartWriter *hlsl::NewProgramSignatureWriter(const DxilModule &M, DXIL::Signa
     return new DxilProgramSignatureWriter(
         M.GetOutputSignature(), domain, false,
         M.GetUseMinPrecision());
-  case DXIL::SignatureKind::PatchConstant:
+  case DXIL::SignatureKind::PatchConstOrPrim:
     return new DxilProgramSignatureWriter(
-        M.GetPatchConstantSignature(), domain,
+        M.GetPatchConstOrPrimSignature(), domain,
         /*IsInput*/ M.GetShaderModel()->IsDS(),
         /*UseMinPrecision*/M.GetUseMinPrecision());
   case DXIL::SignatureKind::Invalid:
@@ -372,7 +373,7 @@ private:
   SmallVector<uint32_t, 8> m_SemanticIndexBuffer;
   std::vector<PSVSignatureElement0> m_SigInputElements;
   std::vector<PSVSignatureElement0> m_SigOutputElements;
-  std::vector<PSVSignatureElement0> m_SigPatchConstantElements;
+  std::vector<PSVSignatureElement0> m_SigPatchConstOrPrimElements;
 
   void SetPSVSigElement(PSVSignatureElement0 &E, const DxilSignatureElement &SE) {
     memset(&E, 0, sizeof(PSVSignatureElement0));
@@ -468,8 +469,8 @@ public:
       m_SigInputElements.resize(m_PSVInitInfo.SigInputElements);
       m_PSVInitInfo.SigOutputElements = m_Module.GetOutputSignature().GetElements().size();
       m_SigOutputElements.resize(m_PSVInitInfo.SigOutputElements);
-      m_PSVInitInfo.SigPatchConstantElements = m_Module.GetPatchConstantSignature().GetElements().size();
-      m_SigPatchConstantElements.resize(m_PSVInitInfo.SigPatchConstantElements);
+      m_PSVInitInfo.SigPatchConstOrPrimElements = m_Module.GetPatchConstOrPrimSignature().GetElements().size();
+      m_SigPatchConstOrPrimElements.resize(m_PSVInitInfo.SigPatchConstOrPrimElements);
       uint32_t i = 0;
       for (auto &SE : m_Module.GetInputSignature().GetElements()) {
         SetPSVSigElement(m_SigInputElements[i++], *(SE.get()));
@@ -479,8 +480,8 @@ public:
         SetPSVSigElement(m_SigOutputElements[i++], *(SE.get()));
       }
       i = 0;
-      for (auto &SE : m_Module.GetPatchConstantSignature().GetElements()) {
-        SetPSVSigElement(m_SigPatchConstantElements[i++], *(SE.get()));
+      for (auto &SE : m_Module.GetPatchConstOrPrimSignature().GetElements()) {
+        SetPSVSigElement(m_SigPatchConstOrPrimElements[i++], *(SE.get()));
       }
       // Set String and SemanticInput Tables
       m_PSVInitInfo.StringTable.Table = m_StringBuffer.data();
@@ -493,12 +494,9 @@ public:
       for (unsigned streamIndex = 0; streamIndex < 4; streamIndex++) {
         m_PSVInitInfo.SigOutputVectors[streamIndex] = m_Module.GetOutputSignature().NumVectorsUsed(streamIndex);
       }
-      m_PSVInitInfo.SigPatchConstantVectors = 0;
-      if (SM->IsHS()) {
-        m_PSVInitInfo.SigPatchConstantVectors = m_Module.GetPatchConstantSignature().NumVectorsUsed(0);
-      }
-      if (SM->IsDS()) {
-        m_PSVInitInfo.SigPatchConstantVectors = m_Module.GetPatchConstantSignature().NumVectorsUsed(0);
+      m_PSVInitInfo.SigPatchConstOrPrimVectors = 0;
+      if (SM->IsHS() || SM->IsDS() || SM->IsMS()) {
+        m_PSVInitInfo.SigPatchConstOrPrimVectors = m_Module.GetPatchConstOrPrimSignature().NumVectorsUsed(0);
       }
     }
     if (!m_PSV.InitNew(m_PSVInitInfo, nullptr, &m_PSVBufferSize)) {
@@ -603,10 +601,58 @@ public:
         }
         break;
       }
+    case ShaderModel::Kind::Amplification:
     case ShaderModel::Kind::Compute:
     case ShaderModel::Kind::Library:
     case ShaderModel::Kind::Invalid:
-      // Compute, Library, and Invalide not relevant to PSVRuntimeInfo0
+      // Amplification, Compute, Library, and Invalide not relevant to PSVRuntimeInfo0
+      break;
+    case ShaderModel::Kind::Mesh:
+      pInfo->MS.MaxOutputVertices = (UINT)m_Module.GetMaxOutputVertices();
+      pInfo->MS.MaxOutputPrimitives = (UINT)m_Module.GetMaxOutputPrimitives();
+      pInfo1->MeshOutputTopology = (UINT)m_Module.GetMeshOutputTopology();
+      Module *mod = m_Module.GetModule();
+      const DataLayout &DL = mod->getDataLayout();
+      unsigned totalByteSize = 0;
+      for (GlobalVariable &GV : mod->globals()) {
+        PointerType *gvPtrType = cast<PointerType>(GV.getType());
+        if (gvPtrType->getAddressSpace() == hlsl::DXIL::kTGSMAddrSpace) {
+          Type *gvType = gvPtrType->getPointerElementType();
+          unsigned byteSize = DL.getTypeAllocSize(gvType);
+          totalByteSize += byteSize;
+        }
+      }
+      pInfo->MS.GroupSharedBytesUsed = totalByteSize;
+
+      const Function *entryFunc = m_Module.GetEntryFunction();
+      unsigned payloadByteSize = 0;
+      for (auto b = entryFunc->begin(), bend = entryFunc->end(); b != bend; ++b) {
+        auto i = b->begin(), iend = b->end();
+        for (; i != iend; ++i) {
+          const Instruction &I = *i;
+
+          // Calls to external functions.
+          const CallInst *CI = dyn_cast<CallInst>(&I);
+          if (CI) {
+            Function *FCalled = CI->getCalledFunction();
+            if (FCalled->isDeclaration()) {
+              Value *opcodeVal = CI->getOperand(0);
+              ConstantInt *OpcodeConst = dyn_cast<ConstantInt>(opcodeVal);
+              unsigned opcode = OpcodeConst->getLimitedValue();
+              DXIL::OpCode dxilOpcode = (DXIL::OpCode)opcode;
+              if (dxilOpcode == DXIL::OpCode::GetMeshPayload) {
+                PointerType *payloadPTy = cast<PointerType>(CI->getType());
+                Type *payloadTy = payloadPTy->getPointerElementType();
+                payloadByteSize = DL.getTypeAllocSize(payloadTy);
+                break;
+              }
+            }
+          }
+        }
+        if (i != iend)
+          break;
+      }
+      pInfo->MS.PayloadSizeInBytes = payloadByteSize;
       break;
     }
 
@@ -689,10 +735,10 @@ public:
         DXASSERT_NOMSG(pOutputElement);
         memcpy(pOutputElement, &m_SigOutputElements[i], sizeof(PSVSignatureElement0));
       }
-      for (unsigned i = 0; i < m_PSV.GetSigPatchConstantElements(); i++) {
-        PSVSignatureElement0 *pPatchConstantElement = m_PSV.GetPatchConstantElement0(i);
-        DXASSERT_NOMSG(pPatchConstantElement);
-        memcpy(pPatchConstantElement, &m_SigPatchConstantElements[i], sizeof(PSVSignatureElement0));
+      for (unsigned i = 0; i < m_PSV.GetSigPatchConstOrPrimElements(); i++) {
+        PSVSignatureElement0 *pPatchConstOrPrimElement = m_PSV.GetPatchConstOrPrimElement0(i);
+        DXASSERT_NOMSG(pPatchConstOrPrimElement);
+        memcpy(pPatchConstOrPrimElement, &m_SigPatchConstOrPrimElements[i], sizeof(PSVSignatureElement0));
       }
 
       // Gather ViewID dependency information
@@ -707,7 +753,7 @@ public:
           if (!SM->IsGS())
             break;
         }
-        if (SM->IsHS()) {
+        if (SM->IsHS() || SM->IsMS()) {
           const uint32_t PCScalars = *(pSrc++);
           pSrc = CopyViewIDState(pSrc, InputScalars, PCScalars, m_PSV.GetViewIDPCOutputMask(), m_PSV.GetInputToPCOutputTable());
         } else if (SM->IsDS()) {
@@ -1491,7 +1537,7 @@ void hlsl::SerializeDxilContainerForModule(DxilModule *pModule,
 
   std::unique_ptr<DxilProgramSignatureWriter> pInputSigWriter = nullptr;
   std::unique_ptr<DxilProgramSignatureWriter> pOutputSigWriter = nullptr;
-  std::unique_ptr<DxilProgramSignatureWriter> pPatchConstantSigWriter = nullptr;
+  std::unique_ptr<DxilProgramSignatureWriter> pPatchConstOrPrimSigWriter = nullptr;
   if (!pModule->GetShaderModel()->IsLib()) {
     DXIL::TessellatorDomain domain = DXIL::TessellatorDomain::Undefined;
     if (pModule->GetShaderModel()->IsHS() || pModule->GetShaderModel()->IsDS())
@@ -1514,15 +1560,15 @@ void hlsl::SerializeDxilContainerForModule(DxilModule *pModule,
                      pOutputSigWriter->write(pStream);
                    });
 
-    pPatchConstantSigWriter = llvm::make_unique<DxilProgramSignatureWriter>(
-        pModule->GetPatchConstantSignature(), domain,
+    pPatchConstOrPrimSigWriter = llvm::make_unique<DxilProgramSignatureWriter>(
+        pModule->GetPatchConstOrPrimSignature(), domain,
         /*IsInput*/ pModule->GetShaderModel()->IsDS(),
         /*UseMinPrecision*/ pModule->GetUseMinPrecision());
-    if (pModule->GetPatchConstantSignature().GetElements().size()) {
+    if (pModule->GetPatchConstOrPrimSignature().GetElements().size()) {
       writer.AddPart(DFCC_PatchConstantSignature,
-                     pPatchConstantSigWriter->size(),
+                     pPatchConstOrPrimSigWriter->size(),
                      [&](AbstractMemoryStream *pStream) {
-                       pPatchConstantSigWriter->write(pStream);
+                       pPatchConstOrPrimSigWriter->write(pStream);
                      });
     }
   }

+ 1 - 1
lib/DxilPIXPasses/DxilShaderAccessTracking.cpp

@@ -453,7 +453,7 @@ bool DxilShaderAccessTracking::runOnModule(Module &M)
 
 
     // todo: should "GetDimensions" mean a resource access?
-    static_assert(DXIL::OpCode::NumOpCodes == static_cast<DXIL::OpCode>(168), "Please update PIX passes if any resource access opcodes are added");
+    static_assert(DXIL::OpCode::NumOpCodes == static_cast<DXIL::OpCode>(174), "Please update PIX passes if any resource access opcodes are added");
     ResourceAccessFunction raFunctions[] = {
       { DXIL::OpCode::CBufferLoadLegacy     , ShaderAccessFlags::Read   , false, f32i32f64 },
       { DXIL::OpCode::CBufferLoad           , ShaderAccessFlags::Read   , false, f16f32f64i16i32i64 },

+ 46 - 33
lib/HLSL/ComputeViewIdState.cpp

@@ -42,11 +42,11 @@ DxilViewIdState::DxilViewIdState(DxilModule *pDxilModule)
     : m_pModule(pDxilModule) {}
 unsigned DxilViewIdState::getNumInputSigScalars() const                   { return m_NumInputSigScalars; }
 unsigned DxilViewIdState::getNumOutputSigScalars(unsigned StreamId) const { return m_NumOutputSigScalars[StreamId]; }
-unsigned DxilViewIdState::getNumPCSigScalars() const                      { return m_NumPCSigScalars; }
+unsigned DxilViewIdState::getNumPCSigScalars() const                      { return m_NumPCOrPrimSigScalars; }
 const DxilViewIdState::OutputsDependentOnViewIdType   &DxilViewIdState::getOutputsDependentOnViewId(unsigned StreamId) const    { return m_OutputsDependentOnViewId[StreamId]; }
-const DxilViewIdState::OutputsDependentOnViewIdType   &DxilViewIdState::getPCOutputsDependentOnViewId() const                   { return m_PCOutputsDependentOnViewId; }
+const DxilViewIdState::OutputsDependentOnViewIdType   &DxilViewIdState::getPCOutputsDependentOnViewId() const                   { return m_PCOrPrimOutputsDependentOnViewId; }
 const DxilViewIdState::InputsContributingToOutputType &DxilViewIdState::getInputsContributingToOutputs(unsigned StreamId) const { return m_InputsContributingToOutputs[StreamId]; }
-const DxilViewIdState::InputsContributingToOutputType &DxilViewIdState::getInputsContributingToPCOutputs() const                { return m_InputsContributingToPCOutputs; }
+const DxilViewIdState::InputsContributingToOutputType &DxilViewIdState::getInputsContributingToPCOutputs() const                { return m_InputsContributingToPCOrPrimOutputs; }
 const DxilViewIdState::InputsContributingToOutputType &DxilViewIdState::getPCInputsContributingToOutputs() const                { return m_PCInputsContributingToOutputs; }
 
 namespace {
@@ -95,39 +95,52 @@ void DxilViewIdState::PrintSets(llvm::raw_ostream &OS) {
   const ShaderModel *pSM = m_pModule->GetShaderModel();
   OS << "ViewId state: \n";
 
-  if (!pSM->IsGS()) {
-    OS << "Number of inputs: " << m_NumInputSigScalars     << 
-                 ", outputs: " << m_NumOutputSigScalars[0] << 
-              ", patchconst: " << m_NumPCSigScalars        << "\n";
-  } else {
+  if (pSM->IsGS()) {
     OS << "Number of inputs: "   << m_NumInputSigScalars     << 
                  ", outputs: { " << m_NumOutputSigScalars[0] << ", " << m_NumOutputSigScalars[1] << ", " <<
                                     m_NumOutputSigScalars[2] << ", " << m_NumOutputSigScalars[3] << " }" <<
-              ", patchconst: "   << m_NumPCSigScalars        << "\n";
+              ", patchconst: "   << m_NumPCOrPrimSigScalars        << "\n";
+  } else if (pSM->IsMS()) {
+    OS << "Number of inputs: " << m_NumInputSigScalars <<
+      ", vertex outputs: " << m_NumOutputSigScalars[0] <<
+      ", primitive outputs: " << m_NumPCOrPrimSigScalars << "\n";
+  } else {
+    OS << "Number of inputs: " << m_NumInputSigScalars <<
+      ", outputs: " << m_NumOutputSigScalars[0] <<
+      ", patchconst: " << m_NumPCOrPrimSigScalars << "\n";
   }
 
-  if (!pSM->IsGS()) {
-    PrintOutputsDependentOnViewId(OS, "Outputs", m_NumOutputSigScalars[0], m_OutputsDependentOnViewId[0]);
-  } else {
+  if (pSM->IsGS()) {
     PrintOutputsDependentOnViewId(OS, "Outputs for Stream0", m_NumOutputSigScalars[0], m_OutputsDependentOnViewId[0]);
     PrintOutputsDependentOnViewId(OS, "Outputs for Stream1", m_NumOutputSigScalars[1], m_OutputsDependentOnViewId[1]);
     PrintOutputsDependentOnViewId(OS, "Outputs for Stream2", m_NumOutputSigScalars[2], m_OutputsDependentOnViewId[2]);
     PrintOutputsDependentOnViewId(OS, "Outputs for Stream3", m_NumOutputSigScalars[3], m_OutputsDependentOnViewId[3]);
+  } else if (pSM->IsMS()) {
+    PrintOutputsDependentOnViewId(OS, "Vertex Outputs", m_NumOutputSigScalars[0], m_OutputsDependentOnViewId[0]);
+  } else {
+    PrintOutputsDependentOnViewId(OS, "Outputs", m_NumOutputSigScalars[0], m_OutputsDependentOnViewId[0]);
   }
+
   if (pSM->IsHS()) {
-    PrintOutputsDependentOnViewId(OS, "PCOutputs", m_NumPCSigScalars, m_PCOutputsDependentOnViewId);
+    PrintOutputsDependentOnViewId(OS, "PCOutputs", m_NumPCOrPrimSigScalars, m_PCOrPrimOutputsDependentOnViewId);
+  } else if (pSM->IsMS()) {
+    PrintOutputsDependentOnViewId(OS, "Primitive Outputs", m_NumPCOrPrimSigScalars, m_PCOrPrimOutputsDependentOnViewId);
   }
 
-  if (!pSM->IsGS()) {
-    PrintInputsContributingToOutputs(OS, "Inputs", "Outputs", m_InputsContributingToOutputs[0]);
-  } else {
+  if (pSM->IsGS()) {
     PrintInputsContributingToOutputs(OS, "Inputs", "Outputs for Stream0", m_InputsContributingToOutputs[0]);
     PrintInputsContributingToOutputs(OS, "Inputs", "Outputs for Stream1", m_InputsContributingToOutputs[1]);
     PrintInputsContributingToOutputs(OS, "Inputs", "Outputs for Stream2", m_InputsContributingToOutputs[2]);
     PrintInputsContributingToOutputs(OS, "Inputs", "Outputs for Stream3", m_InputsContributingToOutputs[3]);
+  } else if (pSM->IsMS()) {
+    PrintInputsContributingToOutputs(OS, "Inputs", "Vertex Outputs", m_InputsContributingToOutputs[0]);
+  } else {
+    PrintInputsContributingToOutputs(OS, "Inputs", "Outputs", m_InputsContributingToOutputs[0]);
   }
   if (pSM->IsHS()) {
-    PrintInputsContributingToOutputs(OS, "Inputs", "PCOutputs", m_InputsContributingToPCOutputs);
+    PrintInputsContributingToOutputs(OS, "Inputs", "PCOutputs", m_InputsContributingToPCOrPrimOutputs);
+  } else if (pSM->IsMS()) {
+    PrintInputsContributingToOutputs(OS, "Inputs", "Primitive Outputs", m_InputsContributingToPCOrPrimOutputs);
   } else if (pSM->IsDS()) {
     PrintInputsContributingToOutputs(OS, "PCInputs", "Outputs", m_PCInputsContributingToOutputs);
   }
@@ -141,9 +154,9 @@ void DxilViewIdState::Clear() {
     m_OutputsDependentOnViewId[i].reset();
     m_InputsContributingToOutputs[i].clear();
   }
-  m_NumPCSigScalars     = 0;
-  m_PCOutputsDependentOnViewId.reset();
-  m_InputsContributingToPCOutputs.clear();
+  m_NumPCOrPrimSigScalars     = 0;
+  m_PCOrPrimOutputsDependentOnViewId.reset();
+  m_InputsContributingToPCOrPrimOutputs.clear();
   m_PCInputsContributingToOutputs.clear();
   m_SerializedState.clear();
 }
@@ -212,15 +225,15 @@ void DxilViewIdState::Serialize() {
     }
     Size += NumInputs * NumOutUINTs; // m_InputsContributingToOutputs[StreamId]
   }
-  if (pSM->IsHS() || pSM->IsDS()) {
+  if (pSM->IsHS() || pSM->IsDS() || pSM->IsMS()) {
     Size += 1; // #PatchConstant.
     unsigned NumPCs = getNumPCSigScalars();
     unsigned NumPCUINTs = RoundUpToUINT(NumPCs);
-    if (pSM->IsHS()) {
+    if (pSM->IsHS() || pSM->IsMS()) {
       if (m_bUsesViewId) {
-        Size += NumPCUINTs; // m_PCOutputsDependentOnViewId
+        Size += NumPCUINTs; // m_PCOrPrimOutputsDependentOnViewId
       }
-      Size += NumInputs * NumPCUINTs; // m_InputsContributingToPCOutputs
+      Size += NumInputs * NumPCUINTs; // m_InputsContributingToPCOrPrimOutputs
     } else {
       unsigned NumOutputs = getNumOutputSigScalars(0);
       unsigned NumOutUINTs = RoundUpToUINT(NumOutputs);
@@ -244,16 +257,16 @@ void DxilViewIdState::Serialize() {
     SerializeInputsContributingToOutput(
         NumInputs, NumOutputs, m_InputsContributingToOutputs[StreamId], pData);
   }
-  if (pSM->IsHS() || pSM->IsDS()) {
+  if (pSM->IsHS() || pSM->IsDS() || pSM->IsMS()) {
     unsigned NumPCs = getNumPCSigScalars();
     *pData++ = NumPCs;
-    if (pSM->IsHS()) {
+    if (pSM->IsHS() || pSM->IsMS()) {
       if (m_bUsesViewId) {
-        SerializeOutputsDependentOnViewId(NumPCs, m_PCOutputsDependentOnViewId,
+        SerializeOutputsDependentOnViewId(NumPCs, m_PCOrPrimOutputsDependentOnViewId,
                                           pData);
       }
       SerializeInputsContributingToOutput(
-          NumInputs, NumPCs, m_InputsContributingToPCOutputs, pData);
+          NumInputs, NumPCs, m_InputsContributingToPCOrPrimOutputs, pData);
     } else {
       unsigned NumOutputs = getNumOutputSigScalars(0);
       SerializeInputsContributingToOutput(
@@ -348,18 +361,18 @@ void DxilViewIdState::Deserialize(const unsigned *pData,
         &pData[ConsumedUINTs], DataSizeInUINTs - ConsumedUINTs);
   }
 
-  if (pSM->IsHS() || pSM->IsDS()) {
+  if (pSM->IsHS() || pSM->IsDS() || pSM->IsMS()) {
     IFTBOOL(DataSizeInUINTs - ConsumedUINTs >= 1, DXC_E_GENERAL_INTERNAL_ERROR);
     unsigned NumPCs = pData[ConsumedUINTs++];
-    m_NumPCSigScalars = NumPCs;
-    if (pSM->IsHS()) {
+    m_NumPCOrPrimSigScalars = NumPCs;
+    if (pSM->IsHS() || pSM->IsMS()) {
       if (m_bUsesViewId) {
         ConsumedUINTs += DeserializeOutputsDependentOnViewId(
-            NumPCs, m_PCOutputsDependentOnViewId, &pData[ConsumedUINTs],
+            NumPCs, m_PCOrPrimOutputsDependentOnViewId, &pData[ConsumedUINTs],
             DataSizeInUINTs - ConsumedUINTs);
       }
       ConsumedUINTs += DeserializeInputsContributingToOutput(
-          NumInputs, NumPCs, m_InputsContributingToPCOutputs,
+          NumInputs, NumPCs, m_InputsContributingToPCOrPrimOutputs,
           &pData[ConsumedUINTs], DataSizeInUINTs - ConsumedUINTs);
     } else {
       unsigned NumOutputs = getNumOutputSigScalars(0);

+ 36 - 17
lib/HLSL/ComputeViewIdStateBuilder.cpp

@@ -7,7 +7,9 @@
 //                                                                           //
 ///////////////////////////////////////////////////////////////////////////////
 
+#include "dxc/HlslIntrinsicOp.h"
 #include "dxc/HLSL/ComputeViewIdState.h"
+#include "dxc/HLSL/HLOperations.h"
 #include "dxc/Support/Global.h"
 #include "dxc/DXIL/DxilModule.h"
 #include "dxc/DXIL/DxilOperations.h"
@@ -52,13 +54,13 @@ public:
         m_NumInputSigScalars(state.m_NumInputSigScalars),
         m_NumOutputSigScalars(state.m_NumOutputSigScalars,
                               DxilViewIdStateData::kNumStreams),
-        m_NumPCSigScalars(state.m_NumPCSigScalars),
+        m_NumPCOrPrimSigScalars(state.m_NumPCOrPrimSigScalars),
         m_OutputsDependentOnViewId(state.m_OutputsDependentOnViewId,
                                    DxilViewIdStateData::kNumStreams),
-        m_PCOutputsDependentOnViewId(state.m_PCOutputsDependentOnViewId),
+        m_PCOrPrimOutputsDependentOnViewId(state.m_PCOrPrimOutputsDependentOnViewId),
         m_InputsContributingToOutputs(state.m_InputsContributingToOutputs,
                                       DxilViewIdStateData::kNumStreams),
-        m_InputsContributingToPCOutputs(state.m_InputsContributingToPCOutputs),
+        m_InputsContributingToPCOrPrimOutputs(state.m_InputsContributingToPCOrPrimOutputs),
         m_PCInputsContributingToOutputs(state.m_PCInputsContributingToOutputs),
         m_bUsesViewId(state.m_bUsesViewId) {}
 
@@ -71,15 +73,15 @@ private:
 
   unsigned &m_NumInputSigScalars;
   MutableArrayRef<unsigned> m_NumOutputSigScalars;
-  unsigned &m_NumPCSigScalars;
+  unsigned &m_NumPCOrPrimSigScalars;
 
   // Set of scalar outputs dependent on ViewID.
   MutableArrayRef<OutputsDependentOnViewIdType> m_OutputsDependentOnViewId;
-  OutputsDependentOnViewIdType &m_PCOutputsDependentOnViewId;
+  OutputsDependentOnViewIdType &m_PCOrPrimOutputsDependentOnViewId;
 
   // Set of scalar inputs contributing to computation of scalar outputs.
   MutableArrayRef<InputsContributingToOutputType> m_InputsContributingToOutputs;
-  InputsContributingToOutputType &m_InputsContributingToPCOutputs; // HS PC only.
+  InputsContributingToOutputType &m_InputsContributingToPCOrPrimOutputs; // HS PC and MS Prim only.
   InputsContributingToOutputType &m_PCInputsContributingToOutputs; // DS only.
 
   bool &m_bUsesViewId;
@@ -174,7 +176,7 @@ void DxilViewIdStateBuilder::Compute() {
   // 1. Traverse signature MD to determine max packed location.
   DetermineMaxPackedLocation(m_pModule->GetInputSignature(), &m_NumInputSigScalars, 1);
   DetermineMaxPackedLocation(m_pModule->GetOutputSignature(), &m_NumOutputSigScalars[0], pSM->IsGS() ? kNumStreams : 1);
-  DetermineMaxPackedLocation(m_pModule->GetPatchConstantSignature(), &m_NumPCSigScalars, 1);
+  DetermineMaxPackedLocation(m_pModule->GetPatchConstOrPrimSignature(), &m_NumPCOrPrimSigScalars, 1);
 
   // 2. Collect sets of functions reachable from main and pc entries.
   CallGraphAnalysis CGA;
@@ -205,10 +207,10 @@ void DxilViewIdStateBuilder::Compute() {
                      m_OutputsDependentOnViewId[StreamId],
                      m_InputsContributingToOutputs[StreamId], false);
   }
-  if (pSM->IsHS()) {
+  if (pSM->IsHS() || pSM->IsMS()) {
     CreateViewIdSets(m_PCEntry.ContributingInstructions[0],
-                     m_PCOutputsDependentOnViewId,
-                     m_InputsContributingToPCOutputs, true);
+                     m_PCOrPrimOutputsDependentOnViewId,
+                     m_InputsContributingToPCOrPrimOutputs, true);
   } else if (pSM->IsDS()) {
     OutputsDependentOnViewIdType OutputsDependentOnViewId;
     CreateViewIdSets(m_Entry.ContributingInstructions[0],
@@ -233,12 +235,12 @@ void DxilViewIdStateBuilder::Clear() {
     m_OutputsDependentOnViewId[i].reset();
     m_InputsContributingToOutputs[i].clear();
   }
-  m_NumPCSigScalars     = 0;
+  m_NumPCOrPrimSigScalars     = 0;
   m_InpSigDynIdxElems.clear();
   m_OutSigDynIdxElems.clear();
   m_PCSigDynIdxElems.clear();
-  m_PCOutputsDependentOnViewId.reset();
-  m_InputsContributingToPCOutputs.clear();
+  m_PCOrPrimOutputsDependentOnViewId.reset();
+  m_InputsContributingToPCOrPrimOutputs.clear();
   m_PCInputsContributingToOutputs.clear();
   m_Entry.Clear();
   m_PCEntry.Clear();
@@ -340,6 +342,12 @@ void DxilViewIdStateBuilder::AnalyzeFunctions(EntryInfo &Entry) {
           GetUnsignedVal(SO.get_rowIndex(), (uint32_t*)&row);
           IFTBOOL(GetUnsignedVal(SO.get_colIndex(), &col), DXC_E_GENERAL_INTERNAL_ERROR);
           Entry.Outputs.emplace(CI);
+        } else if (DxilInst_StorePOutput SPO = DxilInst_StorePOutput(CI)) {
+          pDynIdxElems = &m_PCSigDynIdxElems;
+          IFTBOOL(GetUnsignedVal(SPO.get_outputSigId(), &id), DXC_E_GENERAL_INTERNAL_ERROR);
+          GetUnsignedVal(SPO.get_rowIndex(), (uint32_t*)&row);
+          IFTBOOL(GetUnsignedVal(SPO.get_colIndex(), &col), DXC_E_GENERAL_INTERNAL_ERROR);
+          Entry.Outputs.emplace(CI);
         } else if (DxilInst_LoadPatchConstant LPC = DxilInst_LoadPatchConstant(CI)) {
           if (m_pModule->GetShaderModel()->IsDS()) {
             pDynIdxElems = &m_PCSigDynIdxElems;
@@ -412,8 +420,14 @@ void DxilViewIdStateBuilder::CollectValuesContributingToOutputs(EntryInfo &Entry
       GetUnsignedVal(SO.get_outputSigId(), &id);
       GetUnsignedVal(SO.get_colIndex(), &col);
       GetUnsignedVal(SO.get_rowIndex(), (uint32_t*)&startRow);
+    } else if (DxilInst_StorePOutput SPO = DxilInst_StorePOutput(CI)) {
+      pDxilSig = &m_pModule->GetPatchConstOrPrimSignature();
+      pContributingValue = SPO.get_value();
+      GetUnsignedVal(SPO.get_outputSigId(), &id);
+      GetUnsignedVal(SPO.get_colIndex(), &col);
+      GetUnsignedVal(SPO.get_rowIndex(), (uint32_t*)&startRow);
     } else if (DxilInst_StorePatchConstant SPC = DxilInst_StorePatchConstant(CI)) {
-      pDxilSig = &m_pModule->GetPatchConstantSignature();
+      pDxilSig = &m_pModule->GetPatchConstOrPrimSignature();
       pContributingValue = SPC.get_value();
       GetUnsignedVal(SPC.get_outputSigID(), &id);
       GetUnsignedVal(SPC.get_row(), (uint32_t*)&startRow);
@@ -669,6 +683,11 @@ void DxilViewIdStateBuilder::CollectReachingDeclsRec(Value *pValue, ValueSetType
     CollectReachingDeclsRec(SelI->getFalseValue(), ReachingDecls, Visited);
   } else if (dyn_cast<Argument>(pValue)) {
     ReachingDecls.emplace(pValue);
+  } else if (CallInst *call = dyn_cast<CallInst>(pValue)) {
+    Function *func = call->getCalledFunction();
+    StringRef funcName = func->getName();
+    DXASSERT(funcName.startswith("dx.op.getMeshPayload"), "the function must be @dx.op.getMeshPayload here.");
+    ReachingDecls.emplace(pValue);
   } else {
     IFT(DXC_E_GENERAL_INTERNAL_ERROR);
   }
@@ -762,7 +781,7 @@ void DxilViewIdStateBuilder::CreateViewIdSets(const std::unordered_map<unsigned,
           GetUnsignedVal(LPC.get_inputSigId(), &inpId);
           GetUnsignedVal(LPC.get_col(), &col);
           GetUnsignedVal(LPC.get_row(), (uint32_t*)&startRow);
-          pSigElem = &m_pModule->GetPatchConstantSignature().GetElement(inpId);
+          pSigElem = &m_pModule->GetPatchConstOrPrimSignature().GetElement(inpId);
         }
       } else {
         continue;
@@ -787,7 +806,7 @@ void DxilViewIdStateBuilder::CreateViewIdSets(const std::unordered_map<unsigned,
             // This HS patch-constant output depends on an input value of LoadOutputControlPoint
             // that is the output value of the HS main (control-point) function.
             // Transitively update this (patch-constant) output dependence on main (control-point) output.
-            DXASSERT_NOMSG(&OutputsDependentOnViewId == &m_PCOutputsDependentOnViewId);
+            DXASSERT_NOMSG(&OutputsDependentOnViewId == &m_PCOrPrimOutputsDependentOnViewId);
             OutputsDependentOnViewId[outIdx] = OutputsDependentOnViewId[outIdx] || m_OutputsDependentOnViewId[0][index];
 
             const auto it = m_InputsContributingToOutputs[0].find(index);
@@ -813,7 +832,7 @@ unsigned DxilViewIdStateBuilder::GetLinearIndex(DxilSignatureElement &SigElem, i
 void DxilViewIdStateBuilder::UpdateDynamicIndexUsageState() const {
   UpdateDynamicIndexUsageStateForSig(m_pModule->GetInputSignature(), m_InpSigDynIdxElems);
   UpdateDynamicIndexUsageStateForSig(m_pModule->GetOutputSignature(), m_OutSigDynIdxElems);
-  UpdateDynamicIndexUsageStateForSig(m_pModule->GetPatchConstantSignature(), m_PCSigDynIdxElems);
+  UpdateDynamicIndexUsageStateForSig(m_pModule->GetPatchConstOrPrimSignature(), m_PCSigDynIdxElems);
 }
 
 void DxilViewIdStateBuilder::UpdateDynamicIndexUsageStateForSig(DxilSignature &Sig,

+ 3 - 3
lib/HLSL/DxilContainerReflection.cpp

@@ -1854,7 +1854,7 @@ HRESULT DxilShaderReflection::Load(IDxcBlob *pBlob,
     // Populate input/output/patch constant signatures.
     CreateReflectionObjectsForSignature(m_pDxilModule->GetInputSignature(), m_InputSignature);
     CreateReflectionObjectsForSignature(m_pDxilModule->GetOutputSignature(), m_OutputSignature);
-    CreateReflectionObjectsForSignature(m_pDxilModule->GetPatchConstantSignature(), m_PatchConstantSignature);
+    CreateReflectionObjectsForSignature(m_pDxilModule->GetPatchConstOrPrimSignature(), m_PatchConstantSignature);
     MarkUsedSignatureElements();
     return S_OK;
   }
@@ -1969,14 +1969,14 @@ void DxilShaderReflection::MarkUsedSignatureElements() {
       if (!GetUnsignedVal(SPC.get_col(), &col)) continue;
       if (!GetUnsignedVal(SPC.get_row(), &row)) continue;
       pDescs = &m_PatchConstantSignature;
-      pSig = &m_pDxilModule->GetPatchConstantSignature();
+      pSig = &m_pDxilModule->GetPatchConstOrPrimSignature();
     }
     else if (LPC) {
       if (!GetUnsignedVal(LPC.get_inputSigId(), &sigId)) continue;
       if (!GetUnsignedVal(LPC.get_col(), &col)) continue;
       if (!GetUnsignedVal(LPC.get_row(), &row)) continue;
       pDescs = &m_PatchConstantSignature;
-      pSig = &m_pDxilModule->GetPatchConstantSignature();
+      pSig = &m_pDxilModule->GetPatchConstOrPrimSignature();
     }
     else {
       continue;

+ 16 - 5
lib/HLSL/DxilPreserveAllOutputs.cpp

@@ -31,7 +31,10 @@ public:
   explicit OutputWrite(CallInst *call)
     : m_Call(call)
   {
-    assert(DxilInst_StoreOutput(call) || DxilInst_StorePatchConstant(call));
+    assert(DxilInst_StoreOutput(call) ||
+           DxilInst_StoreVertexOutput(call) ||
+           DxilInst_StorePrimitiveOutput(call) ||
+           DxilInst_StorePatchConstant(call));
   }
 
   unsigned GetSignatureID() const {
@@ -40,7 +43,7 @@ public:
   }
 
   DxilSignatureElement &GetSignatureElement(DxilModule &DM) const {
-    if (DxilInst_StorePatchConstant(m_Call))
+    if (DxilInst_StorePatchConstant(m_Call) || DxilInst_StorePrimitiveOutput(m_Call))
       return DM.GetPatchConstantSignature().GetElement(GetSignatureID());
     else
       return DM.GetOutputSignature().GetElement(GetSignatureID());
@@ -169,8 +172,14 @@ private:
   }
 
   DXIL::OpCode GetOutputOpCode() const {
-    if (m_OutputElement.IsPatchConstant())
-      return DXIL::OpCode::StorePatchConstant;
+    if (m_OutputElement.IsPatchConstOrPrim()) {
+      if (m_OutputElement.GetSigPointKind() == DXIL::SigPointKind::PCOut)
+        return DXIL::OpCode::StorePatchConstant;
+      else {
+        assert(m_OutputElement.GetSigPointKind() == DXIL::SigPointKind::MSPOut);
+        return DXIL::OpCode::StorePrimitiveOutput;
+      }
+    }
     else
       return DXIL::OpCode::StoreOutput;
   }
@@ -228,9 +237,11 @@ DxilPreserveAllOutputs::OutputVec DxilPreserveAllOutputs::collectOutputStores(Fu
   for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
     Instruction *inst = &*I;
     DxilInst_StoreOutput storeOutput(inst);
+    DxilInst_StoreVertexOutput storeVertexOutput(inst);
+    DxilInst_StorePrimitiveOutput storePrimitiveOutput(inst);
     DxilInst_StorePatchConstant storePatch(inst);
 
-    if (storeOutput || storePatch)
+    if (storeOutput || storeVertexOutput || storePrimitiveOutput || storePatch)
       calls.emplace_back(cast<CallInst>(inst));
   }
   return calls;

+ 429 - 41
lib/HLSL/DxilValidation.cpp

@@ -191,6 +191,12 @@ const char *hlsl::GetValidationRuleText(ValidationRule value) {
     case hlsl::ValidationRule::InstrAttributeAtVertexNoInterpolation: return "Attribute %0 must have nointerpolation mode in order to use GetAttributeAtVertex function.";
     case hlsl::ValidationRule::InstrCreateHandleImmRangeID: return "Local resource must map to global resource.";
     case hlsl::ValidationRule::InstrSignatureOperationNotInEntry: return "Dxil operation for input output signature must be in entryPoints.";
+    case hlsl::ValidationRule::InstrMultipleSetMeshOutputCounts: return "SetMeshOUtputCounts cannot be called multiple times.";
+    case hlsl::ValidationRule::InstrMissingSetMeshOutputCounts: return "Missing SetMeshOutputCounts call.";
+    case hlsl::ValidationRule::InstrNonDominatingSetMeshOutputCounts: return "Non-Dominating SetMeshOutputCounts call.";
+    case hlsl::ValidationRule::InstrMultipleGetMeshPayload: return "GetMeshPayload cannot be called multiple times.";
+    case hlsl::ValidationRule::InstrNotOnceDispatchMesh: return "DispatchMesh must be called exactly once in an Amplification shader.";
+    case hlsl::ValidationRule::InstrNonDominatingDispatchMesh: return "Non-Dominating DispatchMesh call.";
     case hlsl::ValidationRule::TypesNoVector: return "Vector type '%0' is not allowed";
     case hlsl::ValidationRule::TypesDefined: return "Type '%0' is not defined on DXIL primitives";
     case hlsl::ValidationRule::TypesIntWidth: return "Int type '%0' has an invalid width";
@@ -202,6 +208,7 @@ const char *hlsl::GetValidationRuleText(ValidationRule value) {
     case hlsl::ValidationRule::SmOperand: return "Operand must be defined in target shader model";
     case hlsl::ValidationRule::SmSemantic: return "Semantic '%0' is invalid as %1 %2";
     case hlsl::ValidationRule::SmNoInterpMode: return "Interpolation mode for '%0' is set but should be undefined";
+    case hlsl::ValidationRule::SmConstantInterpMode: return "Interpolation mode for '%0' should be constant";
     case hlsl::ValidationRule::SmNoPSOutputIdx: return "Pixel shader output registers are not indexable.";
     case hlsl::ValidationRule::SmPSConsistentInterp: return "Interpolation mode for PS input position must be linear_noperspective_centroid or linear_noperspective_sample when outputting oDepthGE or oDepthLE and not running at sample frequency (which is forced by inputting SV_SampleIndex or declaring an input linear_sample or linear_noperspective_sample)";
     case hlsl::ValidationRule::SmThreadGroupChannelRange: return "Declared Thread Group %0 size %1 outside valid range [%2..%3]";
@@ -253,6 +260,16 @@ const char *hlsl::GetValidationRuleText(ValidationRule value) {
     case hlsl::ValidationRule::Sm64bitRawBufferLoadStore: return "i64/f64 rawBufferLoad/Store overloads are allowed after SM 6.3";
     case hlsl::ValidationRule::SmRayShaderSignatures: return "Ray tracing shader '%0' should not have any shader signatures";
     case hlsl::ValidationRule::SmRayShaderPayloadSize: return "For shader '%0', %1 size is smaller than argument's allocation size";
+    case hlsl::ValidationRule::SmMeshShaderMaxVertexCount: return "MS max vertex output count must be [0..%0].  %1 specified";
+    case hlsl::ValidationRule::SmMeshShaderMaxPrimitiveCount: return "MS max primitive output count must be [0..%0].  %1 specified";
+    case hlsl::ValidationRule::SmMeshShaderPayloadSize: return "For shader '%0', payload size is greater than %1";
+    case hlsl::ValidationRule::SmMeshShaderOutputSize: return "For shader '%0', vertex plus primitive output size is greater than %1";
+    case hlsl::ValidationRule::SmMeshShaderInOutSize: return "For shader '%0', input plus output size is greater than %1";
+    case hlsl::ValidationRule::SmMeshVSigRowCount: return "For shader '%0', vertex output signatures are taking up more than %1 rows";
+    case hlsl::ValidationRule::SmMeshPSigRowCount: return "For shader '%0', primitive output signatures are taking up more than %1 rows";
+    case hlsl::ValidationRule::SmMeshTotalSigRowCount: return "For shader '%0', vertex and primitive output signatures are taking up more than %1 rows";
+    case hlsl::ValidationRule::SmMaxMSSMSize: return "Total Thread Group Shared Memory storage is %0, exceeded %1";
+    case hlsl::ValidationRule::SmAmplificationShaderPayloadSize: return "For shader '%0', payload size is greater than %1";
     case hlsl::ValidationRule::UniNoWaveSensitiveGradient: return "Gradient operations are not affected by wave-sensitive data or control flow.";
     case hlsl::ValidationRule::FlowReducible: return "Execution flow must be reducible";
     case hlsl::ValidationRule::FlowNoRecusion: return "Recursion is not permitted";
@@ -356,7 +373,7 @@ struct EntryStatus {
   bool hasOutputPosition[DXIL::kNumOutputStreams];
   unsigned OutputPositionMask[DXIL::kNumOutputStreams];
   std::vector<unsigned> outputCols;
-  std::vector<unsigned> patchConstCols;
+  std::vector<unsigned> patchConstOrPrimCols;
   bool m_bCoverageIn, m_bInnerCoverageIn;
   bool hasViewID;
   unsigned domainLocSize;
@@ -368,8 +385,8 @@ struct EntryStatus {
     }
 
     outputCols.resize(entryProps.sig.OutputSignature.GetElements().size(), 0);
-    patchConstCols.resize(
-        entryProps.sig.PatchConstantSignature.GetElements().size(), 0);
+    patchConstOrPrimCols.resize(
+        entryProps.sig.PatchConstOrPrimSignature.GetElements().size(), 0);
   }
 };
 
@@ -781,7 +798,7 @@ static bool ValidateOpcodeInProfile(DXIL::OpCode opcode,
   // Instructions: ThreadId=93, GroupId=94, ThreadIdInGroup=95,
   // FlattenedThreadIdInGroup=96
   if ((93 <= op && op <= 96))
-    return (SK == DXIL::ShaderKind::Compute);
+    return (SK == DXIL::ShaderKind::Compute || SK == DXIL::ShaderKind::Mesh || SK == DXIL::ShaderKind::Amplification);
   // Instructions: DomainLocation=105
   if (op == 105)
     return (SK == DXIL::ShaderKind::Domain);
@@ -815,7 +832,7 @@ static bool ValidateOpcodeInProfile(DXIL::OpCode opcode,
   // Instructions: ViewID=138
   if (op == 138)
     return (major > 6 || (major == 6 && minor >= 1))
-        && (SK == DXIL::ShaderKind::Vertex || SK == DXIL::ShaderKind::Hull || SK == DXIL::ShaderKind::Domain || SK == DXIL::ShaderKind::Geometry || SK == DXIL::ShaderKind::Pixel);
+        && (SK == DXIL::ShaderKind::Vertex || SK == DXIL::ShaderKind::Hull || SK == DXIL::ShaderKind::Domain || SK == DXIL::ShaderKind::Geometry || SK == DXIL::ShaderKind::Pixel || SK == DXIL::ShaderKind::Mesh);
   // Instructions: RawBufferLoad=139, RawBufferStore=140
   if ((139 <= op && op <= 140))
     return (major > 6 || (major == 6 && minor >= 2));
@@ -860,6 +877,15 @@ static bool ValidateOpcodeInProfile(DXIL::OpCode opcode,
   // WaveMultiPrefixBitCount=167
   if ((165 <= op && op <= 167))
     return (major > 6 || (major == 6 && minor >= 5));
+  // Instructions: DispatchMesh=173
+  if (op == 173)
+    return (major > 6 || (major == 6 && minor >= 5))
+        && (SK == DXIL::ShaderKind::Amplification);
+  // Instructions: SetMeshOutputCounts=168, EmitIndices=169, GetMeshPayload=170,
+  // StoreVertexOutput=171, StorePrimitiveOutput=172
+  if ((168 <= op && op <= 172))
+    return (major > 6 || (major == 6 && minor >= 5))
+        && (SK == DXIL::ShaderKind::Mesh);
   return true;
   // VALOPCODESM-TEXT:END
 }
@@ -889,8 +915,8 @@ static unsigned ValidateSignatureRowCol(Instruction *I,
   } else {
     if (SE.IsOutput())
       Status.outputCols[SE.GetID()] |= 1 << col;
-    if (SE.IsPatchConstant())
-      Status.patchConstCols[SE.GetID()] |= 1 << col;
+    if (SE.IsPatchConstOrPrim())
+      Status.patchConstOrPrimCols[SE.GetID()] |= 1 << col;
   }
 
   return col;
@@ -1463,10 +1489,13 @@ static void ValidateSignatureDxilOp(CallInst *CI, DXIL::OpCode opcode,
       }
     }
   } break;
-  case DXIL::OpCode::StoreOutput: {
+  case DXIL::OpCode::StoreOutput:
+  case DXIL::OpCode::StoreVertexOutput: 
+  case DXIL::OpCode::StorePrimitiveOutput: {
     Value *outputID =
         CI->getArgOperand(DXIL::OperandIndex::kStoreOutputIDOpIdx);
-    DxilSignature &outputSig = S.OutputSignature;
+    DxilSignature &outputSig = opcode == DXIL::OpCode::StorePrimitiveOutput ?
+      S.PatchConstOrPrimSignature : S.OutputSignature;
     Value *row = CI->getArgOperand(DXIL::OperandIndex::kStoreOutputRowOpIdx);
     Value *col = CI->getArgOperand(DXIL::OperandIndex::kStoreOutputColOpIdx);
     ValidateSignatureAccess(CI, outputSig, outputID, row, col, Status, ValCtx);
@@ -1509,7 +1538,7 @@ static void ValidateSignatureDxilOp(CallInst *CI, DXIL::OpCode opcode,
         DxilEntrySignature &S = EntryProps.sig;
         Value *outputID =
             CI->getArgOperand(DXIL::OperandIndex::kStoreOutputIDOpIdx);
-        DxilSignature &outputSig = S.PatchConstantSignature;
+        DxilSignature &outputSig = S.PatchConstOrPrimSignature;
         Value *row =
             CI->getArgOperand(DXIL::OperandIndex::kStoreOutputRowOpIdx);
         Value *col =
@@ -1615,6 +1644,30 @@ static void ValidateSignatureDxilOp(CallInst *CI, DXIL::OpCode opcode,
                                   {"Emit/CutStream", "Geometry shader"});
     }
   } break;
+  case DXIL::OpCode::EmitIndices: {
+    if (!props.IsMS()) {
+      ValCtx.EmitInstrFormatError(CI, ValidationRule::SmOpcodeInInvalidFunction,
+                                  {"EmitIndices", "Mesh shader"});
+    }
+  } break;
+  case DXIL::OpCode::SetMeshOutputCounts: {
+    if (!props.IsMS()) {
+      ValCtx.EmitInstrFormatError(CI, ValidationRule::SmOpcodeInInvalidFunction,
+                                  {"SetMeshOutputCounts", "Mesh shader"});
+    }
+  } break;
+  case DXIL::OpCode::GetMeshPayload: {
+    if (!props.IsMS()) {
+      ValCtx.EmitInstrFormatError(CI, ValidationRule::SmOpcodeInInvalidFunction,
+                                  {"GetMeshPayload", "Mesh shader"});
+    }
+  } break;
+  case DXIL::OpCode::DispatchMesh: {
+    if (!props.IsAS()) {
+      ValCtx.EmitInstrFormatError(CI, ValidationRule::SmOpcodeInInvalidFunction,
+                                  {"DispatchMesh", "Amplification shader"});
+    }
+  } break;
   default:
     break;
   }
@@ -2284,6 +2337,8 @@ static void ValidateDxilOperationCallInProfile(CallInst *CI,
   case DXIL::OpCode::LoadInput:
   case DXIL::OpCode::DomainLocation:
   case DXIL::OpCode::StoreOutput:
+  case DXIL::OpCode::StoreVertexOutput:
+  case DXIL::OpCode::StorePrimitiveOutput:
   case DXIL::OpCode::OutputControlPointID:
   case DXIL::OpCode::LoadOutputControlPoint:
   case DXIL::OpCode::StorePatchConstant:
@@ -2700,6 +2755,110 @@ static void ValidateGradientOps(Function *F, ArrayRef<CallInst *> ops, ArrayRef<
   }
 }
 
+static void ValidateMsIntrinsics(Function *F,
+                                 ValidationContext &ValCtx,
+                                 CallInst *setMeshOutputCounts,
+                                 CallInst *getMeshPayload) {
+  if (ValCtx.DxilMod.HasDxilFunctionProps(F)) {
+    DXIL::ShaderKind shaderKind = ValCtx.DxilMod.GetDxilFunctionProps(F).shaderKind;
+    if (shaderKind != DXIL::ShaderKind::Mesh)
+      return;
+  } else {
+    return;
+  }
+
+  DominatorTreeAnalysis DTA;
+  DominatorTree DT = DTA.run(*F);
+
+  for (auto b = F->begin(), bend = F->end(); b != bend; ++b) {
+    bool foundSetMeshOutputCountsInCurrentBB = false;
+    for (auto i = b->begin(), iend = b->end(); i != iend; ++i) {
+      llvm::Instruction &I = *i;
+
+      // Calls to external functions.
+      CallInst *CI = dyn_cast<CallInst>(&I);
+      if (CI) {
+        Function *FCalled = CI->getCalledFunction();
+        if (FCalled->isDeclaration()) {
+          // External function validation will diagnose.
+          if (!IsDxilFunction(FCalled)) {
+            continue;
+          }
+
+          if (CI == setMeshOutputCounts) {
+            foundSetMeshOutputCountsInCurrentBB = true;
+          }
+          Value *opcodeVal = CI->getOperand(0);
+          ConstantInt *OpcodeConst = dyn_cast<ConstantInt>(opcodeVal);
+          unsigned opcode = OpcodeConst->getLimitedValue();
+          DXIL::OpCode dxilOpcode = (DXIL::OpCode)opcode;
+
+          if (dxilOpcode == DXIL::OpCode::StoreVertexOutput ||
+              dxilOpcode == DXIL::OpCode::StorePrimitiveOutput ||
+              dxilOpcode == DXIL::OpCode::EmitIndices) {
+            if (setMeshOutputCounts == nullptr) {
+              ValCtx.EmitInstrError(&I, ValidationRule::InstrMissingSetMeshOutputCounts);
+            } else if (!foundSetMeshOutputCountsInCurrentBB &&
+                       !DT.dominates(setMeshOutputCounts->getParent(), I.getParent())) {
+              ValCtx.EmitInstrError(&I, ValidationRule::InstrNonDominatingSetMeshOutputCounts);
+            }
+          }
+        }
+      }
+    }
+  }
+
+  if (getMeshPayload) {
+    PointerType *payloadPTy = cast<PointerType>(getMeshPayload->getType());
+    StructType *payloadTy = cast<StructType>(payloadPTy->getPointerElementType());
+    const DataLayout &DL = F->getParent()->getDataLayout();
+    unsigned payloadSize = DL.getTypeAllocSize(payloadTy);
+
+    if (payloadSize > DXIL::kMaxMSASPayloadSize) {
+      ValCtx.EmitFormatError(ValidationRule::SmMeshShaderPayloadSize,
+        { F->getName(), std::to_string(DXIL::kMaxMSASPayloadSize) });
+    }
+
+    DxilFunctionProps &prop = ValCtx.DxilMod.GetDxilFunctionProps(F);
+    prop.ShaderProps.MS.payloadByteSize = payloadSize;
+  }
+}
+
+static void ValidateAsIntrinsics(Function *F, ValidationContext &ValCtx, CallInst *dispatchMesh) {
+  if (ValCtx.DxilMod.HasDxilFunctionProps(F)) {
+    DXIL::ShaderKind shaderKind = ValCtx.DxilMod.GetDxilFunctionProps(F).shaderKind;
+    if (shaderKind != DXIL::ShaderKind::Amplification)
+      return;
+  }
+  else {
+    return;
+  }
+
+  if (dispatchMesh == nullptr) {
+    ValCtx.EmitError(ValidationRule::InstrNotOnceDispatchMesh);
+    return;
+  }
+
+  PostDominatorTree PDT;
+  PDT.runOnFunction(*F);
+
+  if (!PDT.dominates(dispatchMesh->getParent(), &F->getEntryBlock())) {
+    ValCtx.EmitInstrError(dispatchMesh, ValidationRule::InstrNonDominatingDispatchMesh);
+  }
+
+  Function *dispatchMeshFunc = dispatchMesh->getCalledFunction();
+  FunctionType *dispatchMeshFuncTy = dispatchMeshFunc->getFunctionType();
+  PointerType *payloadPTy = cast<PointerType>(dispatchMeshFuncTy->getParamType(4));
+  StructType *payloadTy = cast<StructType>(payloadPTy->getPointerElementType());
+  const DataLayout &DL = F->getParent()->getDataLayout();
+  unsigned payloadSize = DL.getTypeAllocSize(payloadTy);
+
+  if (payloadSize > DXIL::kMaxMSASPayloadSize) {
+    ValCtx.EmitFormatError(ValidationRule::SmAmplificationShaderPayloadSize,
+      { F->getName(), std::to_string(DXIL::kMaxMSASPayloadSize) });
+  }
+}
+
 static void ValidateControlFlowHint(BasicBlock &bb, ValidationContext &ValCtx) {
   // Validate controlflow hint.
   TerminatorInst *TI = bb.getTerminator();
@@ -2957,11 +3116,49 @@ static bool IsLLVMInstructionAllowedForLib(Instruction &I, ValidationContext &Va
   }
 }
 
+static bool IsFromMeshPayload(Instruction *I) {
+  unsigned opcode = I->getOpcode();
+  switch (opcode) {
+  case Instruction::Alloca: {
+    break;
+  }
+  case Instruction::GetElementPtr: {
+    Value *src0 = I->getOperand(0);
+    if (I = dyn_cast<Instruction>(src0)) {
+      return IsFromMeshPayload(I);
+    }
+    return false;
+  }
+  case Instruction::Store: {
+    Value *src1 = I->getOperand(1);
+    if (I = dyn_cast<Instruction>(src1)) {
+      return IsFromMeshPayload(I);
+    }
+    return false;
+  }
+  default:
+    return false;
+  }
+
+  for (auto user : I->users()) {
+    if (CallInst *CI = dyn_cast<CallInst>(user)) {
+      Function *func = CI->getCalledFunction();
+      StringRef funcName = func->getName();
+      if (funcName.startswith("dx.op.dispatchMesh"))
+        return true;
+    }
+  }
+  return false;
+}
+
 static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
   bool SupportsMinPrecision =
       ValCtx.DxilMod.GetGlobalFlags() & DXIL::kEnableMinPrecision;
   SmallVector<CallInst *, 16> gradientOps;
   SmallVector<CallInst *, 16> barriers;
+  CallInst *setMeshOutputCounts = nullptr;
+  CallInst *getMeshPayload = nullptr;
+  CallInst *dispatchMesh = nullptr;
   for (auto b = F->begin(), bend = F->end(); b != bend; ++b) {
     for (auto i = b->begin(), iend = b->end(); i != iend; ++i) {
       llvm::Instruction &I = *i;
@@ -3023,6 +3220,30 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
           // External function validation will check the parameter
           // list. This function will check that the call does not
           // violate any rules.
+
+          if (dxilOpcode == DXIL::OpCode::SetMeshOutputCounts) {
+            // validate the call count of SetMeshOutputCounts
+            if (setMeshOutputCounts != nullptr) {
+              ValCtx.EmitInstrError(&I, ValidationRule::InstrMultipleSetMeshOutputCounts);
+            }
+            setMeshOutputCounts = CI;
+          }
+
+          if (dxilOpcode == DXIL::OpCode::GetMeshPayload) {
+            // validate the call count of GetMeshPayload
+            if (getMeshPayload != nullptr) {
+              ValCtx.EmitInstrError(&I, ValidationRule::InstrMultipleGetMeshPayload);
+            }
+            getMeshPayload = CI;
+          }
+
+          if (dxilOpcode == DXIL::OpCode::DispatchMesh) {
+            // validate the call count of DispatchMesh
+            if (dispatchMesh != nullptr) {
+              ValCtx.EmitInstrError(&I, ValidationRule::InstrNotOnceDispatchMesh);
+            }
+            dispatchMesh = CI;
+          }
         }
         continue;
       }
@@ -3069,11 +3290,13 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
       unsigned opcode = I.getOpcode();
       switch (opcode) {
       case Instruction::Alloca: {
-        AllocaInst *AI = cast<AllocaInst>(&I);
-        // TODO: validate address space and alignment
-        Type *Ty = AI->getAllocatedType();
-        if (!ValidateType(Ty, ValCtx)) {
-          continue;
+        if (!IsFromMeshPayload(&I)) {
+          AllocaInst *AI = cast<AllocaInst>(&I);
+          // TODO: validate address space and alignment
+          Type *Ty = AI->getAllocatedType();
+          if (!ValidateType(Ty, ValCtx)) {
+            continue;
+          }
         }
       } break;
       case Instruction::ExtractValue: {
@@ -3096,16 +3319,20 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
         }
       } break;
       case Instruction::Store: {
-        StoreInst *SI = cast<StoreInst>(&I);
-        Type *Ty = SI->getValueOperand()->getType();
-        if (!ValidateType(Ty, ValCtx)) {
-          continue;
+        if (!IsFromMeshPayload(&I)) {
+          StoreInst *SI = cast<StoreInst>(&I);
+          Type *Ty = SI->getValueOperand()->getType();
+          if (!ValidateType(Ty, ValCtx)) {
+            continue;
+          }
         }
       } break;
       case Instruction::GetElementPtr: {
-        Type *Ty = I.getType()->getPointerElementType();
-        if (!ValidateType(Ty, ValCtx)) {
-          continue;
+        if (!IsFromMeshPayload(&I)) {
+          Type *Ty = I.getType()->getPointerElementType();
+          if (!ValidateType(Ty, ValCtx)) {
+            continue;
+          }
         }
         GetElementPtrInst *GEP = cast<GetElementPtrInst>(&I);
         bool allImmIndex = true;
@@ -3220,6 +3447,10 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
   if (!gradientOps.empty()) {
     ValidateGradientOps(F, gradientOps, barriers, ValCtx);
   }
+
+  ValidateMsIntrinsics(F, ValCtx, setMeshOutputCounts, getMeshPayload);
+
+  ValidateAsIntrinsics(F, ValCtx, dispatchMesh);
 }
 
 static void ValidateFunction(Function &F, ValidationContext &ValCtx) {
@@ -3458,10 +3689,16 @@ static void ValidateGlobalVariables(ValidationContext &ValCtx) {
     }
   }
 
-  if (TGSMSize > DXIL::kMaxTGSMSize) {
+  if (M.GetShaderModel()->IsMS()) {
+    if (TGSMSize > DXIL::kMaxMSSMSize) {
+      ValCtx.EmitFormatError(ValidationRule::SmMaxMSSMSize,
+                             { std::to_string(TGSMSize),
+                               std::to_string(DXIL::kMaxMSSMSize) });
+    }
+  } else if (TGSMSize > DXIL::kMaxTGSMSize) {
     ValCtx.EmitFormatError(ValidationRule::SmMaxTGSMSize,
-                           {std::to_string(TGSMSize),
-                            std::to_string(DXIL::kMaxTGSMSize)});
+                           { std::to_string(TGSMSize),
+                             std::to_string(DXIL::kMaxTGSMSize) });
   }
   if (!fixAddrTGSMList.empty()) {
     ValidateTGSMRaceCondition(fixAddrTGSMList, ValCtx);
@@ -4408,6 +4645,14 @@ static void ValidateNoInterpModeSignature(ValidationContext &ValCtx, const DxilS
   }
 }
 
+static void ValidateConstantInterpModeSignature(ValidationContext &ValCtx, const DxilSignature &S) {
+  for (auto &E : S.GetElements()) {
+    if (!E->GetInterpolationMode()->IsConstant()) {
+      ValCtx.EmitSignatureError(E.get(), ValidationRule::SmConstantInterpMode);
+    }
+  }
+}
+
 static void ValidateEntrySignatures(ValidationContext &ValCtx,
                                     const DxilEntryProps &entryProps,
                                     EntryStatus &Status,
@@ -4419,7 +4664,7 @@ static void ValidateEntrySignatures(ValidationContext &ValCtx,
     // No signatures allowed
     if (!S.InputSignature.GetElements().empty() ||
         !S.OutputSignature.GetElements().empty() ||
-        !S.PatchConstantSignature.GetElements().empty()) {
+        !S.PatchConstOrPrimSignature.GetElements().empty()) {
       ValCtx.EmitFormatError(ValidationRule::SmRayShaderSignatures, { F.getName() });
     }
 
@@ -4465,6 +4710,7 @@ static void ValidateEntrySignatures(ValidationContext &ValCtx,
   bool isVS = props.IsVS();
   bool isGS = props.IsGS();
   bool isCS = props.IsCS();
+  bool isMS = props.IsMS();
 
   if (isPS) {
     // PS output no interp mode.
@@ -4473,8 +4719,14 @@ static void ValidateEntrySignatures(ValidationContext &ValCtx,
     // VS input no interp mode.
     ValidateNoInterpModeSignature(ValCtx, S.InputSignature);
   }
-  // patch constant no interp mode.
-  ValidateNoInterpModeSignature(ValCtx, S.PatchConstantSignature);
+
+  if (isMS) {
+    // primitive output constant interp mode.
+    ValidateConstantInterpModeSignature(ValCtx, S.PatchConstOrPrimSignature);
+  } else {
+    // patch constant no interp mode.
+    ValidateNoInterpModeSignature(ValCtx, S.PatchConstOrPrimSignature);
+  }
 
   unsigned maxInputScalars = DXIL::kMaxInputTotalScalars;
   unsigned maxOutputScalars = 0;
@@ -4493,13 +4745,18 @@ static void ValidateEntrySignatures(ValidationContext &ValCtx,
       maxOutputScalars = DXIL::kMaxOutputTotalScalars;
       maxPatchConstantScalars = DXIL::kMaxHSOutputPatchConstantTotalScalars;
     break;
+  case DXIL::ShaderKind::Mesh:
+    maxOutputScalars = DXIL::kMaxOutputTotalScalars;
+    maxPatchConstantScalars = DXIL::kMaxOutputTotalScalars;
+    break;
+  case DXIL::ShaderKind::Amplification:
   default:
     break;
   }
 
   ValidateSignature(ValCtx, S.InputSignature, Status, maxInputScalars);
   ValidateSignature(ValCtx, S.OutputSignature, Status, maxOutputScalars);
-  ValidateSignature(ValCtx, S.PatchConstantSignature, Status,
+  ValidateSignature(ValCtx, S.PatchConstOrPrimSignature, Status,
                     maxPatchConstantScalars);
 
   if (isPS) {
@@ -4579,10 +4836,53 @@ static void ValidateEntrySignatures(ValidationContext &ValCtx,
   if (isCS) {
       if (!S.InputSignature.GetElements().empty() ||
           !S.OutputSignature.GetElements().empty() ||
-          !S.PatchConstantSignature.GetElements().empty()) {
+          !S.PatchConstOrPrimSignature.GetElements().empty()) {
         ValCtx.EmitError(ValidationRule::SmCSNoSignatures);
       }
   }
+
+  if (isMS) {
+    unsigned VertexSignatureRows = S.OutputSignature.GetRowCount();
+    if (VertexSignatureRows > DXIL::kMaxMSVSigRows) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshVSigRowCount,
+        { F.getName(), std::to_string(DXIL::kMaxMSVSigRows) });
+    }
+    unsigned PrimitiveSignatureRows = S.PatchConstOrPrimSignature.GetRowCount();
+    if (PrimitiveSignatureRows > DXIL::kMaxMSPSigRows) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshPSigRowCount,
+        { F.getName(), std::to_string(DXIL::kMaxMSPSigRows) });
+    }
+    if (VertexSignatureRows + PrimitiveSignatureRows > DXIL::kMaxMSTotalSigRows) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshTotalSigRowCount,
+        { F.getName(), std::to_string(DXIL::kMaxMSTotalSigRows) });
+    }
+
+    unsigned maxVertexCount = props.ShaderProps.MS.maxVertexCount;
+    unsigned maxPrimitiveCount = props.ShaderProps.MS.maxPrimitiveCount;
+    unsigned totalOutputScalars = 0;
+    for (auto &SE : S.OutputSignature.GetElements()) {
+      totalOutputScalars += SE->GetRows() * SE->GetCols() * maxVertexCount;
+    }
+    for (auto &SE : S.PatchConstOrPrimSignature.GetElements()) {
+      totalOutputScalars += SE->GetRows() * SE->GetCols() * maxPrimitiveCount;
+    }
+
+    if (totalOutputScalars > DXIL::kMaxMSOutputTotalScalars) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshShaderOutputSize,
+        { F.getName(), std::to_string(DXIL::kMaxMSOutputTotalScalars) });
+    }
+
+    unsigned totalInputOutputScalars = totalOutputScalars + props.ShaderProps.MS.payloadByteSize;
+    if (totalInputOutputScalars > DXIL::kMaxMSInputOutputTotalScalars) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshShaderInOutSize,
+        { F.getName(), std::to_string(DXIL::kMaxMSInputOutputTotalScalars) });
+    }
+  }
 }
 
 static void ValidateEntrySignatures(ValidationContext &ValCtx) {
@@ -4617,7 +4917,7 @@ static void CheckPatchConstantSemantic(ValidationContext &ValCtx,
   DXIL::TessellatorDomain domain =
       isHS ? props.ShaderProps.HS.domain : props.ShaderProps.DS.domain;
 
-  const DxilSignature &patchConstantSig = EntryProps.sig.PatchConstantSignature;
+  const DxilSignature &patchConstantSig = EntryProps.sig.PatchConstOrPrimSignature;
 
   const unsigned kQuadEdgeSize = 4;
   const unsigned kQuadInsideSize = 2;
@@ -4765,6 +5065,92 @@ static void ValidateEntryProps(ValidationContext &ValCtx,
                               std::to_string(DXIL::kMaxCSThreadsPerGroup)});
     }
 
+    // type of threadID, thread group ID take care by DXIL operation overload
+    // check.
+  } else if (ShaderType == DXIL::ShaderKind::Mesh) {
+    const auto &MS = props.ShaderProps.MS;
+    unsigned x = MS.numThreads[0];
+    unsigned y = MS.numThreads[1];
+    unsigned z = MS.numThreads[2];
+
+    unsigned threadsInGroup = x * y * z;
+
+    if ((x < DXIL::kMinMSASThreadGroupX) || (x > DXIL::kMaxMSASThreadGroupX)) {
+      ValCtx.EmitFormatError(ValidationRule::SmThreadGroupChannelRange,
+                             {"X", std::to_string(x),
+                              std::to_string(DXIL::kMinMSASThreadGroupX),
+                              std::to_string(DXIL::kMaxMSASThreadGroupX)});
+    }
+    if ((y < DXIL::kMinMSASThreadGroupY) || (y > DXIL::kMaxMSASThreadGroupY)) {
+      ValCtx.EmitFormatError(ValidationRule::SmThreadGroupChannelRange,
+                             {"Y", std::to_string(y),
+                              std::to_string(DXIL::kMinMSASThreadGroupY),
+                              std::to_string(DXIL::kMaxMSASThreadGroupY)});
+    }
+    if ((z < DXIL::kMinMSASThreadGroupZ) || (z > DXIL::kMaxMSASThreadGroupZ)) {
+      ValCtx.EmitFormatError(ValidationRule::SmThreadGroupChannelRange,
+                             {"Z", std::to_string(z),
+                              std::to_string(DXIL::kMinMSASThreadGroupZ),
+                              std::to_string(DXIL::kMaxMSASThreadGroupZ)});
+    }
+
+    if (threadsInGroup > DXIL::kMaxMSASThreadsPerGroup) {
+      ValCtx.EmitFormatError(ValidationRule::SmMaxTheadGroup,
+                             {std::to_string(threadsInGroup),
+                              std::to_string(DXIL::kMaxMSASThreadsPerGroup)});
+    }
+
+    // type of threadID, thread group ID take care by DXIL operation overload
+    // check.
+
+    unsigned maxVertexCount = MS.maxVertexCount;
+    if (maxVertexCount > DXIL::kMaxMSOutputVertexCount) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshShaderMaxVertexCount,
+          { std::to_string(DXIL::kMaxMSOutputVertexCount),
+            std::to_string(maxVertexCount) });
+    }
+
+    unsigned maxPrimitiveCount = MS.maxPrimitiveCount;
+    if (maxPrimitiveCount > DXIL::kMaxMSOutputPrimitiveCount) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmMeshShaderMaxPrimitiveCount,
+          { std::to_string(DXIL::kMaxMSOutputPrimitiveCount),
+            std::to_string(maxPrimitiveCount) });
+    }
+  } else if (ShaderType == DXIL::ShaderKind::Amplification) {
+    const auto &AS = props.ShaderProps.AS;
+    unsigned x = AS.numThreads[0];
+    unsigned y = AS.numThreads[1];
+    unsigned z = AS.numThreads[2];
+
+    unsigned threadsInGroup = x * y * z;
+
+    if ((x < DXIL::kMinMSASThreadGroupX) || (x > DXIL::kMaxMSASThreadGroupX)) {
+      ValCtx.EmitFormatError(ValidationRule::SmThreadGroupChannelRange,
+                             {"X", std::to_string(x),
+                              std::to_string(DXIL::kMinMSASThreadGroupX),
+                              std::to_string(DXIL::kMaxMSASThreadGroupX)});
+    }
+    if ((y < DXIL::kMinMSASThreadGroupY) || (y > DXIL::kMaxMSASThreadGroupY)) {
+      ValCtx.EmitFormatError(ValidationRule::SmThreadGroupChannelRange,
+                             {"Y", std::to_string(y),
+                              std::to_string(DXIL::kMinMSASThreadGroupY),
+                              std::to_string(DXIL::kMaxMSASThreadGroupY)});
+    }
+    if ((z < DXIL::kMinMSASThreadGroupZ) || (z > DXIL::kMaxMSASThreadGroupZ)) {
+      ValCtx.EmitFormatError(ValidationRule::SmThreadGroupChannelRange,
+                             {"Z", std::to_string(z),
+                              std::to_string(DXIL::kMinMSASThreadGroupZ),
+                              std::to_string(DXIL::kMaxMSASThreadGroupZ)});
+    }
+
+    if (threadsInGroup > DXIL::kMaxMSASThreadsPerGroup) {
+      ValCtx.EmitFormatError(ValidationRule::SmMaxTheadGroup,
+                             {std::to_string(threadsInGroup),
+                              std::to_string(DXIL::kMaxMSASThreadsPerGroup)});
+    }
+
     // type of threadID, thread group ID take care by DXIL operation overload
     // check.
   } else if (ShaderType == DXIL::ShaderKind::Domain) {
@@ -5023,10 +5409,10 @@ static void ValidateUninitializedOutput(ValidationContext &ValCtx,
   const DxilFunctionProps &props = entryProps.props;
   // For HS only need to check Tessfactor which is in patch constant sig.
   if (props.IsHS()) {
-    std::vector<unsigned> &patchConstCols = Status.patchConstCols;
-    const DxilSignature &patchConstSig = entryProps.sig.PatchConstantSignature;
+    std::vector<unsigned> &patchConstOrPrimCols = Status.patchConstOrPrimCols;
+    const DxilSignature &patchConstSig = entryProps.sig.PatchConstOrPrimSignature;
     for (auto &E : patchConstSig.GetElements()) {
-      unsigned mask = patchConstCols[E->GetID()];
+      unsigned mask = patchConstOrPrimCols[E->GetID()];
       unsigned requireMask = (1 << E->GetCols()) - 1;
       // TODO: check other case uninitialized output is allowed.
       if (mask != requireMask && !E->GetSemantic()->IsArbitrary()) {
@@ -5214,8 +5600,8 @@ static void VerifySignatureMatches(_In_ ValidationContext &ValCtx,
   case hlsl::DXIL::SignatureKind::Output:
     pName = "Program Output Signature";
     break;
-  case hlsl::DXIL::SignatureKind::PatchConstant:
-    pName = "Program Patch Constant Signature";
+  case hlsl::DXIL::SignatureKind::PatchConstOrPrim:
+    pName = "Program Patch Constant or Primitive Signature";
     break;
   default:
     break;
@@ -5360,7 +5746,9 @@ HRESULT ValidateDxilContainerParts(llvm::Module *pModule,
   ValidationContext ValCtx(*pModule, pDebugModule, *pDxilModule, DiagPrinter);
 
   DXIL::ShaderKind ShaderKind = pDxilModule->GetShaderModel()->GetKind();
-  bool bTess = ShaderKind == DXIL::ShaderKind::Hull || ShaderKind == DXIL::ShaderKind::Domain;
+  bool bTessOrMesh = ShaderKind == DXIL::ShaderKind::Hull ||
+                     ShaderKind == DXIL::ShaderKind::Domain ||
+                     ShaderKind == DXIL::ShaderKind::Mesh;
 
   std::unordered_set<uint32_t> FourCCFound;
   const DxilPartHeader *pRootSignaturePart = nullptr;
@@ -5398,8 +5786,8 @@ HRESULT ValidateDxilContainerParts(llvm::Module *pModule,
       if (ValCtx.isLibProfile) {
         ValCtx.EmitFormatError(ValidationRule::ContainerPartInvalid, { szFourCC });
       } else {
-        if (bTess) {
-          VerifySignatureMatches(ValCtx, DXIL::SignatureKind::PatchConstant, GetDxilPartData(pPart), pPart->PartSize);
+        if (bTessOrMesh) {
+          VerifySignatureMatches(ValCtx, DXIL::SignatureKind::PatchConstOrPrim, GetDxilPartData(pPart), pPart->PartSize);
         } else {
           ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, {"Program Patch Constant Signature"});
         }
@@ -5466,8 +5854,8 @@ HRESULT ValidateDxilContainerParts(llvm::Module *pModule,
     if (FourCCFound.find(DFCC_OutputSignature) == FourCCFound.end()) {
       VerifySignatureMatches(ValCtx, DXIL::SignatureKind::Output, nullptr, 0);
     }
-    if (bTess && FourCCFound.find(DFCC_PatchConstantSignature) == FourCCFound.end() &&
-        pDxilModule->GetPatchConstantSignature().GetElements().size())
+    if (bTessOrMesh && FourCCFound.find(DFCC_PatchConstantSignature) == FourCCFound.end() &&
+        pDxilModule->GetPatchConstOrPrimSignature().GetElements().size())
     {
       ValCtx.EmitFormatError(ValidationRule::ContainerPartMissing, { "Program Patch Constant Signature" });
     }

+ 2 - 0
lib/HLSL/HLModule.cpp

@@ -831,6 +831,8 @@ void HLModule::GetParameterRowsAndCols(Type *Ty, unsigned &rows, unsigned &cols,
   bool skipOneLevelArray = inputQual == DxilParamInputQual::InputPatch;
   skipOneLevelArray |= inputQual == DxilParamInputQual::OutputPatch;
   skipOneLevelArray |= inputQual == DxilParamInputQual::InputPrimitive;
+  skipOneLevelArray |= inputQual == DxilParamInputQual::OutVertices;
+  skipOneLevelArray |= inputQual == DxilParamInputQual::OutPrimitives;
 
   if (skipOneLevelArray) {
     if (Ty->isArrayTy())

+ 32 - 0
lib/HLSL/HLOperationLower.cpp

@@ -2368,6 +2368,36 @@ Value *TranslateFaceforward(CallInst *CI, IntrinsicOp IOP, OP::OpCode op,
   Value *faceforward = Builder.CreateSelect(dotLtZero, n, negN);
   return faceforward;
 }
+
+Value *TrivialSetMeshOutputCounts(CallInst *CI, IntrinsicOp IOP, OP::OpCode op,
+  HLOperationLowerHelper &helper, HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
+  hlsl::OP *hlslOP = &helper.hlslOP;
+  Value *src0 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
+  Value *src1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
+  IRBuilder<> Builder(CI);
+  Constant *opArg = hlslOP->GetU32Const((unsigned)op);
+  Value *args[] = { opArg, src0, src1 };
+  Function *dxilFunc = hlslOP->GetOpFunc(op, Type::getVoidTy(CI->getContext()));
+
+  Builder.CreateCall(dxilFunc, args);
+  return nullptr;
+}
+
+Value *TrivialDispatchMesh(CallInst *CI, IntrinsicOp IOP, OP::OpCode op,
+  HLOperationLowerHelper &helper, HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
+  hlsl::OP *hlslOP = &helper.hlslOP;
+  Value *src0 = CI->getArgOperand(HLOperandIndex::kDispatchMeshOpThreadX);
+  Value *src1 = CI->getArgOperand(HLOperandIndex::kDispatchMeshOpThreadY);
+  Value *src2 = CI->getArgOperand(HLOperandIndex::kDispatchMeshOpThreadZ);
+  Value *src3 = CI->getArgOperand(HLOperandIndex::kDispatchMeshOpPayload);
+  IRBuilder<> Builder(CI);
+  Constant *opArg = hlslOP->GetU32Const((unsigned)op);
+  Value *args[] = { opArg, src0, src1, src2, src3 };
+  Function *dxilFunc = hlslOP->GetOpFunc(op, src3->getType());
+
+  Builder.CreateCall(dxilFunc, args);
+  return nullptr;
+}
 }
 
 // MOP intrinsics
@@ -4799,6 +4829,7 @@ IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] =
     {IntrinsicOp::IOP_D3DCOLORtoUBYTE4, TranslateD3DColorToUByte4, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::IOP_DeviceMemoryBarrier, TrivialBarrier, DXIL::OpCode::Barrier},
     {IntrinsicOp::IOP_DeviceMemoryBarrierWithGroupSync, TrivialBarrier, DXIL::OpCode::Barrier},
+    {IntrinsicOp::IOP_DispatchMesh, TrivialDispatchMesh, DXIL::OpCode::DispatchMesh },
     {IntrinsicOp::IOP_DispatchRaysDimensions, TranslateNoArgVectorOperation, DXIL::OpCode::DispatchRaysDimensions},
     {IntrinsicOp::IOP_DispatchRaysIndex, TranslateNoArgVectorOperation, DXIL::OpCode::DispatchRaysIndex},
     {IntrinsicOp::IOP_EvaluateAttributeAtSample, TranslateEvalSample, DXIL::OpCode::NumOpCodes},
@@ -4847,6 +4878,7 @@ IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] =
     {IntrinsicOp::IOP_RayTCurrent, TrivialNoArgWithRetOperation, DXIL::OpCode::RayTCurrent},
     {IntrinsicOp::IOP_RayTMin, TrivialNoArgWithRetOperation, DXIL::OpCode::RayTMin},
     {IntrinsicOp::IOP_ReportHit, TranslateReportIntersection, DXIL::OpCode::ReportHit},
+    {IntrinsicOp::IOP_SetMeshOutputCounts, TrivialSetMeshOutputCounts, DXIL::OpCode::SetMeshOutputCounts},
     {IntrinsicOp::IOP_TraceRay, TranslateTraceRay, DXIL::OpCode::TraceRay},
     {IntrinsicOp::IOP_WaveActiveAllEqual, TranslateWaveAllEqual, DXIL::OpCode::WaveActiveAllEqual},
     {IntrinsicOp::IOP_WaveActiveAllTrue, TranslateWaveA2B, DXIL::OpCode::WaveAllTrue},

+ 208 - 62
lib/HLSL/HLSignatureLower.cpp

@@ -129,6 +129,18 @@ void replaceInputOutputWithIntrinsic(DXIL::SemanticKind semKind, Value *GV,
   case Semantic::Kind::ViewID:
     opcode = OP::OpCode::ViewID;
     break;
+  case Semantic::Kind::GroupThreadID:
+    opcode = OP::OpCode::ThreadIdInGroup;
+    break;
+  case Semantic::Kind::GroupID:
+    opcode = OP::OpCode::GroupId;
+    break;
+  case Semantic::Kind::DispatchThreadID:
+    opcode = OP::OpCode::ThreadId;
+    break;
+  case Semantic::Kind::GroupIndex:
+    opcode = OP::OpCode::FlattenedThreadIdInGroup;
+    break;
   default:
     DXASSERT(0, "invalid semantic");
     return;
@@ -138,19 +150,25 @@ void replaceInputOutputWithIntrinsic(DXIL::SemanticKind semKind, Value *GV,
   Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
 
   Value *newArg = nullptr;
-  if (semKind == Semantic::Kind::DomainLocation) {
+  if (semKind == Semantic::Kind::DomainLocation ||
+      semKind == Semantic::Kind::GroupThreadID ||
+      semKind == Semantic::Kind::GroupID ||
+      semKind == Semantic::Kind::DispatchThreadID) {
     unsigned vecSize = 1;
     if (Ty->isVectorTy())
       vecSize = Ty->getVectorNumElements();
 
-    newArg = Builder.CreateCall(dxilFunc, {OpArg, hlslOP->GetU8Const(0)});
+    newArg = Builder.CreateCall(dxilFunc, { OpArg,
+      semKind == Semantic::Kind::DomainLocation ? hlslOP->GetU8Const(0) : hlslOP->GetU32Const(0) });
     if (vecSize > 1) {
       Value *result = UndefValue::get(Ty);
       result = Builder.CreateInsertElement(result, newArg, (uint64_t)0);
 
       for (unsigned i = 1; i < vecSize; i++) {
         Value *newElt =
-            Builder.CreateCall(dxilFunc, {OpArg, hlslOP->GetU8Const(i)});
+            Builder.CreateCall(dxilFunc, { OpArg,
+              semKind == Semantic::Kind::DomainLocation ? hlslOP->GetU8Const(i)
+                                                        : hlslOP->GetU32Const(i) });
         result = Builder.CreateInsertElement(result, newElt, i);
       }
       newArg = result;
@@ -223,7 +241,15 @@ void HLSignatureLower::ProcessArgument(Function *func,
       paramAnnotation.GetInterpolationMode().GetKind();
 
   // Set undefined interpMode.
-  if (!sigPoint->NeedsInterpMode())
+  if (sigPoint->GetKind() == DXIL::SigPointKind::MSPOut) {
+    if (interpMode != InterpolationMode::Kind::Undefined &&
+        interpMode != InterpolationMode::Kind::Constant) {
+      Entry->getContext().emitError(
+        "Mesh shader's primitive outputs' interpolation mode must be constant or undefined.");
+    }
+    interpMode = InterpolationMode::Kind::Constant;
+  }
+  else if (!sigPoint->NeedsInterpMode())
     interpMode = InterpolationMode::Kind::Undefined;
   else if (interpMode == InterpolationMode::Kind::Undefined) {
     // Type-based default: linear for floats, constant for others.
@@ -266,7 +292,7 @@ void HLSignatureLower::ProcessArgument(Function *func,
             ? m_InputSemanticsUsed
             : (sigPoint->IsOutput()
                    ? m_OutputSemanticsUsed[streamIdx]
-                   : (sigPoint->IsPatchConstant() ? m_PatchConstantSemanticsUsed
+                   : (sigPoint->IsPatchConstOrPrim() ? m_PatchConstantSemanticsUsed
                                                   : m_OtherSemanticsUsed));
     if (SemanticUseMap.count((unsigned)pSemantic->GetKind()) > 0) {
       auto &SemanticIndexSet = SemanticUseMap[(unsigned)pSemantic->GetKind()];
@@ -346,8 +372,8 @@ void HLSignatureLower::ProcessArgument(Function *func,
   case DXIL::SignatureKind::Output:
     pSig = &EntrySig.OutputSignature;
     break;
-  case DXIL::SignatureKind::PatchConstant:
-    pSig = &EntrySig.PatchConstantSignature;
+  case DXIL::SignatureKind::PatchConstOrPrim:
+    pSig = &EntrySig.PatchConstOrPrimSignature;
     break;
   default:
     DXASSERT(false, "Expected real signature kind at this point");
@@ -359,7 +385,7 @@ void HLSignatureLower::ProcessArgument(Function *func,
   {
     // Add signature element to appropriate maps
     if (isPatchConstantFunction &&
-        sigKind != DXIL::SignatureKind::PatchConstant) {
+        sigKind != DXIL::SignatureKind::PatchConstOrPrim) {
       pSE = FindArgInSignature(arg, paramAnnotation.GetSemanticString(),
                                interpMode, sigPoint->GetKind(), *pSig);
       if (!pSE) {
@@ -415,6 +441,14 @@ void HLSignatureLower::CreateDxilSignatures() {
     if (HLModule::IsStreamOutputPtrType(Ty))
       continue;
 
+    // Skip OutIndices and InPayload
+    DxilParameterAnnotation &paramAnnotation =
+      EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
+    hlsl::DxilParamInputQual qual = paramAnnotation.GetParamInputQual();
+    if (qual == hlsl::DxilParamInputQual::OutIndices ||
+        qual == hlsl::DxilParamInputQual::InPayload)
+      continue;
+
     ProcessArgument(Entry, EntryAnnotation, arg, props, pSM,
                     isPatchConstantFunctionFalse, bForOutFasle, bHasClipPlane);
   }
@@ -463,16 +497,19 @@ void HLSignatureLower::AllocateDxilInputOutputs() {
         "Failed to allocate all input signature elements in available space.");
   }
 
-  hlsl::PackDxilSignature(EntrySig.OutputSignature, packing);
-  if (!EntrySig.OutputSignature.IsFullyAllocated()) {
-    HLM.GetCtx().emitError(
-        "Failed to allocate all output signature elements in available space.");
+  if (props.shaderKind != DXIL::ShaderKind::Amplification) {
+    hlsl::PackDxilSignature(EntrySig.OutputSignature, packing);
+    if (!EntrySig.OutputSignature.IsFullyAllocated()) {
+      HLM.GetCtx().emitError(
+          "Failed to allocate all output signature elements in available space.");
+    }
   }
 
   if (props.shaderKind == DXIL::ShaderKind::Hull ||
-      props.shaderKind == DXIL::ShaderKind::Domain) {
-    hlsl::PackDxilSignature(EntrySig.PatchConstantSignature, packing);
-    if (!EntrySig.PatchConstantSignature.IsFullyAllocated()) {
+      props.shaderKind == DXIL::ShaderKind::Domain ||
+      props.shaderKind == DXIL::ShaderKind::Mesh) {
+    hlsl::PackDxilSignature(EntrySig.PatchConstOrPrimSignature, packing);
+    if (!EntrySig.PatchConstOrPrimSignature.IsFullyAllocated()) {
       HLM.GetCtx().emitError("Failed to allocate all patch constant signature "
                              "elements in available space.");
     }
@@ -494,7 +531,7 @@ void GenerateStOutput(Function *stOutput, MutableArrayRef<Value *> args,
 
 void replaceStWithStOutput(Function *stOutput, StoreInst *stInst,
                            Constant *OpArg, Constant *outputID, Value *idx,
-                           unsigned cols, bool bI1Cast) {
+                           unsigned cols, Value *vertexOrPrimID, bool bI1Cast) {
   IRBuilder<> Builder(stInst);
   Value *val = stInst->getValueOperand();
 
@@ -503,7 +540,9 @@ void replaceStWithStOutput(Function *stOutput, StoreInst *stInst,
     for (unsigned col = 0; col < cols; col++) {
       Value *subVal = Builder.CreateExtractElement(val, col);
       Value *colIdx = Builder.getInt8(col);
-      Value *args[] = {OpArg, outputID, idx, colIdx, subVal};
+      SmallVector<Value *, 4> args = {OpArg, outputID, idx, colIdx, subVal};
+      if (vertexOrPrimID)
+        args.emplace_back(vertexOrPrimID);
       GenerateStOutput(stOutput, args, Builder, bI1Cast);
     }
     // remove stInst
@@ -512,7 +551,9 @@ void replaceStWithStOutput(Function *stOutput, StoreInst *stInst,
     // TODO: support case cols not 1
     DXASSERT(cols == 1, "only support scalar here");
     Value *colIdx = Builder.getInt8(0);
-    Value *args[] = {OpArg, outputID, idx, colIdx, val};
+    SmallVector<Value *, 4> args = {OpArg, outputID, idx, colIdx, val};
+    if (vertexOrPrimID)
+      args.emplace_back(vertexOrPrimID);
     GenerateStOutput(stOutput, args, Builder, bI1Cast);
     // remove stInst
     stInst->eraseFromParent();
@@ -706,22 +747,22 @@ void replaceDirectInputParameter(Value *param, Function *loadInput,
 struct InputOutputAccessInfo {
   // For input output which has only 1 row, idx is 0.
   Value *idx;
-  // VertexID for HS/DS/GS input.
-  Value *vertexID;
+  // VertexID for HS/DS/GS input, MS vertex output. PrimitiveID for MS primitive output
+  Value *vertexOrPrimID;
   // Vector index.
   Value *vectorIdx;
   // Load/Store/LoadMat/StoreMat on input/output.
   Instruction *user;
   InputOutputAccessInfo(Value *index, Instruction *I)
-      : idx(index), vertexID(nullptr), vectorIdx(nullptr), user(I) {}
+      : idx(index), vertexOrPrimID(nullptr), vectorIdx(nullptr), user(I) {}
   InputOutputAccessInfo(Value *index, Instruction *I, Value *ID, Value *vecIdx)
-      : idx(index), vertexID(ID), vectorIdx(vecIdx), user(I) {}
+      : idx(index), vertexOrPrimID(ID), vectorIdx(vecIdx), user(I) {}
 };
 
 void collectInputOutputAccessInfo(
     Value *GV, Constant *constZero,
-    std::vector<InputOutputAccessInfo> &accessInfoList, bool hasVertexID,
-    bool bInput, bool bRowMajor) {
+    std::vector<InputOutputAccessInfo> &accessInfoList, bool hasVertexOrPrimID,
+    bool bInput, bool bRowMajor, bool isMS) {
   // merge GEP use for input output.
   HLModule::MergeGepUse(GV);
   for (auto User = GV->user_begin(); User != GV->user_end();) {
@@ -743,15 +784,15 @@ void collectInputOutputAccessInfo(
       DXASSERT_LOCALVAR(idx, idx->get() == constZero,
                         "only support 0 offset for input pointer");
 
-      Value *vertexID = nullptr;
+      Value *vertexOrPrimID = nullptr;
       Value *vectorIdx = nullptr;
       gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);
 
       // Skip first pointer idx which must be 0.
       GEPIt++;
-      if (hasVertexID) {
-        // Save vertexID.
-        vertexID = GEPIt.getOperand();
+      if (hasVertexOrPrimID) {
+        // Save vertexOrPrimID.
+        vertexOrPrimID = GEPIt.getOperand();
         GEPIt++;
       }
       // Start from first index.
@@ -806,12 +847,12 @@ void collectInputOutputAccessInfo(
         auto GepUserIt = GepUser++;
         if (LoadInst *ldInst = dyn_cast<LoadInst>(*GepUserIt)) {
           if (bInput) {
-            InputOutputAccessInfo info = {idxVal, ldInst, vertexID, vectorIdx};
+            InputOutputAccessInfo info = {idxVal, ldInst, vertexOrPrimID, vectorIdx};
             accessInfoList.push_back(info);
           }
         } else if (StoreInst *stInst = dyn_cast<StoreInst>(*GepUserIt)) {
           if (!bInput) {
-            InputOutputAccessInfo info = {idxVal, stInst, vertexID, vectorIdx};
+            InputOutputAccessInfo info = {idxVal, stInst, vertexOrPrimID, vectorIdx};
             accessInfoList.push_back(info);
           }
         } else if (CallInst *CI = dyn_cast<CallInst>(*GepUserIt)) {
@@ -823,7 +864,7 @@ void collectInputOutputAccessInfo(
                opcode == HLMatLoadStoreOpcode::RowMatLoad)
                   ? bInput
                   : !bInput) {
-            InputOutputAccessInfo info = {idxVal, CI, vertexID, vectorIdx};
+            InputOutputAccessInfo info = {idxVal, CI, vertexOrPrimID, vectorIdx};
             accessInfoList.push_back(info);
           }
         } else {
@@ -842,17 +883,17 @@ void collectInputOutputAccessInfo(
 void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertexIdx,
     Function *ldStFunc, Constant *OpArg, Constant *ID, unsigned cols, bool bI1Cast,
     Constant *columnConsts[],
-    bool bNeedVertexID, bool isArrayTy, bool bInput, bool bIsInout) {
+    bool bNeedVertexOrPrimID, bool isArrayTy, bool bInput, bool bIsInout) {
   Value *idxVal = info.idx;
-  Value *vertexID = undefVertexIdx;
-  if (bNeedVertexID && isArrayTy) {
-    vertexID = info.vertexID;
+  Value *vertexOrPrimID = undefVertexIdx;
+  if (bNeedVertexOrPrimID && isArrayTy) {
+    vertexOrPrimID = info.vertexOrPrimID;
   }
 
   if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
     SmallVector<Value *, 4> args = {OpArg, ID, idxVal, info.vectorIdx};
-    if (vertexID)
-      args.emplace_back(vertexID);
+    if (vertexOrPrimID)
+      args.emplace_back(vertexOrPrimID);
 
     replaceLdWithLdInput(ldStFunc, ldInst, cols, args, bI1Cast);
   } else if (StoreInst *stInst = dyn_cast<StoreInst>(info.user)) {
@@ -861,7 +902,7 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
     } else {
       if (!info.vectorIdx) {
         replaceStWithStOutput(ldStFunc, stInst, OpArg, ID, idxVal, cols,
-                              bI1Cast);
+                              vertexOrPrimID, bI1Cast);
       } else {
         Value *V = stInst->getValueOperand();
         Type *Ty = V->getType();
@@ -873,7 +914,9 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
           if (ColIdx->getType()->getBitWidth() != 8) {
             ColIdx = Builder.getInt8(ColIdx->getValue().getLimitedValue());
           }
-          Value *args[] = {OpArg, ID, idxVal, ColIdx, V};
+          SmallVector<Value *, 4> args = {OpArg, ID, idxVal, ColIdx, V};
+          if (vertexOrPrimID)
+            args.emplace_back(vertexOrPrimID);
           GenerateStOutput(ldStFunc, args, Builder, bI1Cast);
         } else {
           BasicBlock *BB = stInst->getParent();
@@ -894,7 +937,9 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
 
             ConstantInt *CaseIdx = SwitchBuilder.getInt8(i);
 
-            Value *args[] = {OpArg, ID, idxVal, CaseIdx, V};
+            SmallVector<Value *, 4> args = {OpArg, ID, idxVal, CaseIdx, V};
+            if (vertexOrPrimID)
+              args.emplace_back(vertexOrPrimID);
             GenerateStOutput(ldStFunc, args, CaseBuilder, bI1Cast);
 
             CaseBuilder.CreateBr(EndBB);
@@ -927,8 +972,8 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
           Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
           for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
             SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[r] };
-            if (vertexID)
-              args.emplace_back(vertexID);
+            if (vertexOrPrimID)
+              args.emplace_back(vertexOrPrimID);
 
             Value *input = LocalBuilder.CreateCall(ldStFunc, args);
             unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
@@ -941,8 +986,8 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
           Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
           for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
             SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[c] };
-            if (vertexID)
-              args.emplace_back(vertexID);
+            if (vertexOrPrimID)
+              args.emplace_back(vertexOrPrimID);
 
             Value *input = LocalBuilder.CreateCall(ldStFunc, args);
             unsigned matIdx = MatTy.getRowMajorIndex(r, c);
@@ -1003,20 +1048,39 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
 } // namespace
 
 void HLSignatureLower::GenerateDxilInputs() {
-  GenerateDxilInputsOutputs(/*bInput*/ true);
+  GenerateDxilInputsOutputs(DXIL::SignatureKind::Input);
 }
 
 void HLSignatureLower::GenerateDxilOutputs() {
-  GenerateDxilInputsOutputs(/*bInput*/ false);
+  GenerateDxilInputsOutputs(DXIL::SignatureKind::Output);
 }
 
-void HLSignatureLower::GenerateDxilInputsOutputs(bool bInput) {
+void HLSignatureLower::GenerateDxilPrimOutputs() {
+  GenerateDxilInputsOutputs(DXIL::SignatureKind::PatchConstOrPrim);
+}
+
+void HLSignatureLower::GenerateDxilInputsOutputs(DXIL::SignatureKind SK) {
   OP *hlslOP = HLM.GetOP();
   DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
   Module &M = *(HLM.GetModule());
 
-  OP::OpCode opcode = bInput ? OP::OpCode::LoadInput : OP::OpCode::StoreOutput;
-  bool bNeedVertexID = bInput && (props.IsGS() || props.IsDS() || props.IsHS());
+  OP::OpCode opcode;
+  switch (SK) {
+  case DXIL::SignatureKind::Input:
+    opcode = OP::OpCode::LoadInput;
+    break;
+  case DXIL::SignatureKind::Output:
+    opcode = props.IsMS() ? OP::OpCode::StoreVertexOutput : OP::OpCode::StoreOutput;
+    break;
+  case DXIL::SignatureKind::PatchConstOrPrim:
+    opcode = OP::OpCode::StorePrimitiveOutput;
+    break;
+  default:
+    DXASSERT_NOMSG(0);
+  }
+  bool bInput = SK == DXIL::SignatureKind::Input;
+  bool bNeedVertexOrPrimID = bInput && (props.IsGS() || props.IsDS() || props.IsHS());
+  bNeedVertexOrPrimID |= !bInput && props.IsMS();
 
   Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
 
@@ -1030,10 +1094,12 @@ void HLSignatureLower::GenerateDxilInputsOutputs(bool bInput) {
 
   Constant *constZero = hlslOP->GetU32Const(0);
 
-  Value *undefVertexIdx = UndefValue::get(Type::getInt32Ty(HLM.GetCtx()));
+  Value *undefVertexIdx = props.IsMS() || !bInput ? nullptr : UndefValue::get(Type::getInt32Ty(HLM.GetCtx()));
 
   DxilSignature &Sig =
-      bInput ? EntrySig.InputSignature : EntrySig.OutputSignature;
+      bInput ? EntrySig.InputSignature :
+      SK == DXIL::SignatureKind::Output ? EntrySig.OutputSignature :
+      EntrySig.PatchConstOrPrimSignature;
 
   DxilTypeSystem &typeSys = HLM.GetTypeSystem();
   DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(Entry);
@@ -1083,9 +1149,9 @@ void HLSignatureLower::GenerateDxilInputsOutputs(bool bInput) {
 
     if (!GV->getType()->isPointerTy()) {
       DXASSERT(bInput, "direct parameter must be input");
-      Value *vertexID = undefVertexIdx;
+      Value *vertexOrPrimID = undefVertexIdx;
       Value *args[] = {OpArg, ID, /*rowIdx*/ constZero, /*colIdx*/ nullptr,
-                       vertexID};
+                       vertexOrPrimID};
       replaceDirectInputParameter(GV, dxilFunc, cols, args, bI1Cast, hlslOP,
                                   EntryBuilder);
       continue;
@@ -1106,11 +1172,11 @@ void HLSignatureLower::GenerateDxilInputsOutputs(bool bInput) {
     }
     std::vector<InputOutputAccessInfo> accessInfoList;
     collectInputOutputAccessInfo(GV, constZero, accessInfoList,
-                                 bNeedVertexID && bIsArrayTy, bInput, bRowMajor);
+                                 bNeedVertexOrPrimID && bIsArrayTy, bInput, bRowMajor, props.IsMS());
 
     for (InputOutputAccessInfo &info : accessInfoList) {
       GenerateInputOutputUserCall(info, undefVertexIdx, dxilFunc, OpArg, ID,
-                                  cols, bI1Cast, columnConsts, bNeedVertexID,
+                                  cols, bI1Cast, columnConsts, bNeedVertexOrPrimID,
                                   bIsArrayTy, bInput, bIsInout);
     }
   }
@@ -1208,14 +1274,14 @@ void HLSignatureLower::GenerateDxilPatchConstantLdSt() {
   DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
   Module &M = *(HLM.GetModule());
   Constant *constZero = hlslOP->GetU32Const(0);
-  DxilSignature &Sig = EntrySig.PatchConstantSignature;
+  DxilSignature &Sig = EntrySig.PatchConstOrPrimSignature;
   DxilTypeSystem &typeSys = HLM.GetTypeSystem();
   DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(Entry);
   auto InsertPt = Entry->getEntryBlock().getFirstInsertionPt();
   const bool bIsHs = props.IsHS();
   const bool bIsInput = !bIsHs;
   const bool bIsInout = false;
-  const bool bNeedVertexID = false;
+  const bool bNeedVertexOrPrimID = false;
   if (bIsHs) {
     DxilFunctionProps &EntryQual = HLM.GetDxilFunctionProps(Entry);
     Function *patchConstantFunc = EntryQual.ShaderProps.HS.patchConstantFunc;
@@ -1284,8 +1350,8 @@ void HLSignatureLower::GenerateDxilPatchConstantLdSt() {
       }
     }
     std::vector<InputOutputAccessInfo> accessInfoList;
-    collectInputOutputAccessInfo(GV, constZero, accessInfoList, bNeedVertexID,
-                                 bIsInput, bRowMajor);
+    collectInputOutputAccessInfo(GV, constZero, accessInfoList, bNeedVertexOrPrimID,
+                                 bIsInput, bRowMajor, false);
 
     bool bIsArrayTy = GV->getType()->getPointerElementType()->isArrayTy();
     bool isPrecise = m_preciseSigSet.count(SE);
@@ -1294,7 +1360,7 @@ void HLSignatureLower::GenerateDxilPatchConstantLdSt() {
 
     for (InputOutputAccessInfo &info : accessInfoList) {
       GenerateInputOutputUserCall(info, undefVertexIdx, dxilFunc, OpArg, ID,
-                                  cols, bI1Cast, columnConsts, bNeedVertexID,
+                                  cols, bI1Cast, columnConsts, bNeedVertexOrPrimID,
                                   bIsArrayTy, bIsInput, bIsInout);
     }
   }
@@ -1350,12 +1416,12 @@ void HLSignatureLower::GenerateDxilPatchConstantFunctionInputs() {
       }
       std::vector<InputOutputAccessInfo> accessInfoList;
       collectInputOutputAccessInfo(&arg, constZero, accessInfoList,
-                                   /*hasVertexID*/ true, true, bRowMajor);
+                                   /*hasVertexOrPrimID*/ true, true, bRowMajor, false);
       for (InputOutputAccessInfo &info : accessInfoList) {
         if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
           Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
           Value *args[] = {OpArg, inputID, info.idx, info.vectorIdx,
-                           info.vertexID};
+                           info.vertexOrPrimID};
           replaceLdWithLdInput(dxilLdFunc, ldInst, cols, args, bI1Cast);
         } else {
           DXASSERT(0, "input should only be ld");
@@ -1531,10 +1597,87 @@ void HLSignatureLower::GenerateStreamOutputOperations() {
     }
   }
 }
+// Generate DXIL EmitIndices operation.
+void HLSignatureLower::GenerateEmitIndicesOperation(Value *indicesOutput) {
+  OP * hlslOP = HLM.GetOP();
+  Function *DxilFunc = hlslOP->GetOpFunc(OP::OpCode::EmitIndices, Type::getVoidTy(indicesOutput->getContext()));
+  Constant *opArg = hlslOP->GetU32Const((unsigned)OP::OpCode::EmitIndices);
+
+  for (auto U = indicesOutput->user_begin(); U != indicesOutput->user_end();) {
+    Value *user = *(U++);
+    GetElementPtrInst *GEP = cast<GetElementPtrInst>(user);
+    auto idx = GEP->idx_begin();
+    DXASSERT_LOCALVAR(idx, idx->get() == hlslOP->GetU32Const(0),
+                      "only support 0 offset for input pointer");
+    gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);
+
+    // Skip first pointer idx which must be 0.
+    GEPIt++;
+    Value *primIdx = GEPIt.getOperand();
+    DXASSERT(++GEPIt == E, "invalid GEP here");
+
+    auto GepUser = GEP->user_begin();
+    auto GepUserE = GEP->user_end();
+    for (; GepUser != GepUserE;) {
+      auto GepUserIt = GepUser++;
+      StoreInst *stInst = cast<StoreInst>(*GepUserIt);
+      Value *stVal = stInst->getValueOperand();
+      VectorType *VT = cast<VectorType>(stVal->getType());
+      unsigned eleCount = VT->getNumElements();
+      IRBuilder<> Builder(stInst);
+      Value *subVal0 = Builder.CreateExtractElement(stVal, hlslOP->GetU32Const(0));
+      Value *subVal1 = Builder.CreateExtractElement(stVal, hlslOP->GetU32Const(1));
+      Value *subVal2 = eleCount == 3 ?
+        Builder.CreateExtractElement(stVal, hlslOP->GetU32Const(2)) : hlslOP->GetU32Const(0);
+      Value *args[] = { opArg, primIdx, subVal0, subVal1, subVal2 };
+      Builder.CreateCall(DxilFunc, args);
+      stInst->eraseFromParent();
+    }
+    GEP->eraseFromParent();
+  }
+}
+// Generate DXIL EmitIndices operations.
+void HLSignatureLower::GenerateEmitIndicesOperations() {
+  DxilFunctionAnnotation *EntryAnnotation = HLM.GetFunctionAnnotation(Entry);
+  DXASSERT(EntryAnnotation, "must find annotation for entry function");
+
+  for (Argument &arg : Entry->getArgumentList()) {
+    DxilParameterAnnotation &paramAnnotation =
+      EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
+    DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
+    if (inputQual == DxilParamInputQual::OutIndices) {
+      GenerateEmitIndicesOperation(&arg);
+    }
+  }
+}
+// Generate DXIL GetMeshPayload operation.
+void HLSignatureLower::GenerateGetMeshPayloadOperation() {
+  DxilFunctionAnnotation *EntryAnnotation = HLM.GetFunctionAnnotation(Entry);
+  DXASSERT(EntryAnnotation, "must find annotation for entry function");
+
+  for (Argument &arg : Entry->getArgumentList()) {
+    DxilParameterAnnotation &paramAnnotation =
+      EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
+    DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
+    if (inputQual == DxilParamInputQual::InPayload) {
+      OP * hlslOP = HLM.GetOP();
+      Function *DxilFunc = hlslOP->GetOpFunc(OP::OpCode::GetMeshPayload, arg.getType());
+      Constant *opArg = hlslOP->GetU32Const((unsigned)OP::OpCode::GetMeshPayload);
+      IRBuilder<> Builder(arg.getParent()->getEntryBlock().getFirstInsertionPt());
+      Value *args[] = { opArg };
+      Value *payload = Builder.CreateCall(DxilFunc, args);
+      arg.replaceAllUsesWith(payload);
+    }
+  }
+}
 // Lower signatures.
 void HLSignatureLower::Run() {
   DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
   if (props.IsGraphics()) {
+    if (props.IsMS()) {
+      GenerateEmitIndicesOperations();
+      GenerateGetMeshPayloadOperation();
+    }
     CreateDxilSignatures();
 
     // Allocate input output.
@@ -1542,6 +1685,9 @@ void HLSignatureLower::Run() {
 
     GenerateDxilInputs();
     GenerateDxilOutputs();
+    if (props.IsMS()) {
+      GenerateDxilPrimOutputs();
+    }
   } else if (props.IsCS()) {
     GenerateDxilCSInputs();
   }

+ 9 - 1
lib/HLSL/HLSignatureLower.h

@@ -12,6 +12,7 @@
 #pragma once
 #include <unordered_set>
 #include <unordered_map>
+#include "dxc/DXIL/DxilConstants.h"
 
 namespace llvm {
 class Value;
@@ -49,7 +50,8 @@ private:
   // Generate DXIL input load, output store
   void GenerateDxilInputs();
   void GenerateDxilOutputs();
-  void GenerateDxilInputsOutputs(bool bInput);
+  void GenerateDxilPrimOutputs();
+  void GenerateDxilInputsOutputs(DXIL::SignatureKind SK);
   void GenerateDxilCSInputs();
   void GenerateDxilPatchConstantLdSt();
   void GenerateDxilPatchConstantFunctionInputs();
@@ -59,6 +61,12 @@ private:
   void GenerateStreamOutputOperation(llvm::Value *streamVal, unsigned streamID);
   // Generate DXIL stream output operations.
   void GenerateStreamOutputOperations();
+  // Generate DXIL EmitIndices operation.
+  void GenerateEmitIndicesOperation(llvm::Value *indicesOutput);
+  // Generate DXIL EmitIndices operations.
+  void GenerateEmitIndicesOperations();
+  // Generate DXIL GetMeshPayload operation.
+  void GenerateGetMeshPayloadOperation();
 
 private:
   llvm::Function *Entry;

+ 6 - 2
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -1053,7 +1053,8 @@ void SROA_HLSL::isSafeForScalarRepl(Instruction *I, uint64_t Offset,
         IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(CI));
         if (IntrinsicOp::IOP_TraceRay == opcode ||
             IntrinsicOp::IOP_ReportHit == opcode ||
-            IntrinsicOp::IOP_CallShader == opcode) {
+            IntrinsicOp::IOP_CallShader == opcode ||
+            IntrinsicOp::IOP_DispatchMesh == opcode) {
           return MarkUnsafe(Info, User);
         }
       }
@@ -4826,9 +4827,12 @@ void SROA_Parameter_HLSL::flattenArgument(
     std::vector<Value *> Elts;
 
     // Not flat vector for entry function currently.
-    bool SROAed = SROA_Helper::DoScalarReplacement(
+    bool SROAed = false;
+    if (inputQual != DxilParamInputQual::InPayload) {
+      SROAed = SROA_Helper::DoScalarReplacement(
         V, Elts, Builder, /*bFlatVector*/ false, annotation.IsPrecise(),
         dxilTypeSys, DL, DeadInsts);
+    }
 
     if (SROAed) {
       Type *Ty = V->getType()->getPointerElementType();

+ 24 - 0
tools/clang/include/clang/Basic/Attr.td

@@ -871,6 +871,30 @@ def HLSLExport : InheritableAttr {
   let Documentation = [Undocumented];
 }
 
+def HLSLIndices : InheritableAttr {
+  let Spellings = [CXX11<"", "indices", 2015>];
+  let Subjects = SubjectList<[ParmVar]>;
+  let Documentation = [Undocumented];
+}
+
+def HLSLVertices : InheritableAttr {
+  let Spellings = [CXX11<"", "vertices", 2015>];
+  let Subjects = SubjectList<[ParmVar]>;
+  let Documentation = [Undocumented];
+}
+
+def HLSLPrimitives : InheritableAttr {
+  let Spellings = [CXX11<"", "primitives", 2015>];
+  let Subjects = SubjectList<[ParmVar]>;
+  let Documentation = [Undocumented];
+}
+
+def HLSLPayload : InheritableAttr {
+  let Spellings = [CXX11<"", "payload", 2015>];
+  let Subjects = SubjectList<[ParmVar]>;
+  let Documentation = [Undocumented];
+}
+
 // HLSL Change Ends
 
 // SPIRV Change Starts

+ 10 - 4
tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

@@ -7680,14 +7680,20 @@ def err_hlsl_no_struct_user_defined_type: Error<
    "User defined type intrinsic arg must be struct">;
 def err_hlsl_ray_desc_required: Error<
    "Argument type must be struct RayDesc.">;
-def err_hlsl_missing_maxvertexcount_attr: Error<
-   "GS entry point must have the maxvertexcount attribute">;
-def err_hlsl_missing_patchconstantfunc_attr: Error<
-   "HS entry point must have the patchconstantfunc attribute">;
+def err_hlsl_missing_attr: Error<
+   "%0 entry point must have the %1 attribute">;
 def err_hlsl_missing_inout_attr: Error<
    "stream-output object must be an inout parameter">;
 def err_hlsl_unsupported_string_decl: Error<
    "%select{array|parameter|return value}0 of type string is not supported">;
+def err_hlsl_missing_out_attr: Error<
+   "%0 object must be an out parameter">;
+def err_hlsl_missing_in_attr: Error<
+   "%0 object must be an in parameter">;
+def err_hlsl_load_from_mesh_out_arrays: Error<
+   "output arrays of a mesh shader can not be read from">;
+def err_hlsl_out_indices_array_incorrect_access: Error<
+   "a vector in out indices array must be accessed as a whole">;
 // HLSL Change Ends
 
 // SPIRV Change Starts

+ 4 - 0
tools/clang/include/clang/Basic/TokenKinds.def

@@ -513,6 +513,10 @@ KEYWORD(globallycoherent            , KEYHLSL)
 KEYWORD(interface                   , KEYHLSL)
 KEYWORD(sampler_state               , KEYHLSL)
 KEYWORD(technique                   , KEYHLSL)
+KEYWORD(indices                     , KEYHLSL)
+KEYWORD(vertices                    , KEYHLSL)
+KEYWORD(primitives                  , KEYHLSL)
+KEYWORD(payload                     , KEYHLSL)
 ALIAS("Technique", technique        , KEYHLSL)
 ALIAS("technique10", technique      , KEYHLSL)
 ALIAS("technique11", technique      , KEYHLSL)

+ 1 - 0
tools/clang/include/clang/SPIRV/FeatureManager.h

@@ -44,6 +44,7 @@ enum class Extension {
   GOOGLE_hlsl_functionality1,
   GOOGLE_user_type,
   NV_ray_tracing,
+  NV_mesh_shader,
   Unknown,
 };
 

+ 7 - 0
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -518,6 +518,13 @@ public:
   /// \brief Decorates the given target with NoContraction
   void decorateNoContraction(SpirvInstruction *target, SourceLocation);
 
+  /// \brief Decorates the given target with PerPrimitiveNV
+  void decoratePerPrimitiveNV(SpirvInstruction *target, SourceLocation);
+
+  /// \brief Decorates the given target with PerTaskNV
+  void decoratePerTaskNV(SpirvInstruction *target, uint32_t offset,
+                         SourceLocation);
+
   /// --- Constants ---
   /// Each of these methods can acquire a unique constant from the SpirvContext,
   /// and add the context to the list of constants in the module.

+ 4 - 0
tools/clang/include/clang/SPIRV/SpirvContext.h

@@ -228,6 +228,10 @@ public:
     return curShaderModelKind >= ShaderModelKind::RayGeneration &&
            curShaderModelKind <= ShaderModelKind::Callable;
   }
+  bool isMS() const { return curShaderModelKind == ShaderModelKind::Mesh; }
+  bool isAS() const {
+    return curShaderModelKind == ShaderModelKind::Amplification;
+  }
 
 private:
   /// \brief The allocator used to create SPIR-V entity objects.

+ 222 - 14
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -545,6 +545,15 @@ StringToTessOutputPrimitive(StringRef primitive) {
   return DXIL::TessellatorOutputPrimitive::Undefined;
 }
 
+static DXIL::MeshOutputTopology
+StringToMeshOutputTopology(StringRef topology) {
+  if (topology == "line")
+    return DXIL::MeshOutputTopology::Line;
+  if (topology == "triangle")
+    return DXIL::MeshOutputTopology::Triangle;
+  return DXIL::MeshOutputTopology::Undefined;
+}
+
 static unsigned RoundToAlign(unsigned num, unsigned mod) {
   // round num to next highest mod
   if (mod != 0)
@@ -1182,6 +1191,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   bool isVS = false;
   bool isPS = false;
   bool isRay = false;
+  bool isMS = false;
+  bool isAS = false;
   if (const HLSLShaderAttr *Attr = FD->getAttr<HLSLShaderAttr>()) {
     // Stage is already validate in HandleDeclAttributeForHLSL.
     // Here just check first letter (or two).
@@ -1233,12 +1244,32 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
       funcProps->shaderKind = DXIL::ShaderKind::Intersection;
       break;
     case 'a':
-      isRay = true;
-      funcProps->shaderKind = DXIL::ShaderKind::AnyHit;
+      switch (Attr->getStage()[1]) {
+      case 'm':
+        isAS = true;
+        funcProps->shaderKind = DXIL::ShaderKind::Amplification;
+        break;
+      case 'n':
+        isRay = true;
+        funcProps->shaderKind = DXIL::ShaderKind::AnyHit;
+        break;
+      default:
+        break;
+      }
       break;
     case 'm':
-      isRay = true;
-      funcProps->shaderKind = DXIL::ShaderKind::Miss;
+      switch (Attr->getStage()[1]) {
+      case 'e':
+        isMS = true;
+        funcProps->shaderKind = DXIL::ShaderKind::Mesh;
+        break;
+      case 'i':
+        isRay = true;
+        funcProps->shaderKind = DXIL::ShaderKind::Miss;
+        break;
+      default:
+        break;
+      }
       break;
     default:
       break;
@@ -1289,6 +1320,12 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   const ShaderModel *SM = m_pHLModule->GetShaderModel();
   if (isEntry) {
     funcProps->shaderKind = SM->GetKind();
+    if (funcProps->shaderKind == DXIL::ShaderKind::Mesh) {
+      isMS = true;
+    }
+    else if (funcProps->shaderKind == DXIL::ShaderKind::Amplification) {
+      isAS = true;
+    }
   }
 
   // Geometry shader.
@@ -1325,16 +1362,26 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
 
   // Computer shader.
   if (const HLSLNumThreadsAttr *Attr = FD->getAttr<HLSLNumThreadsAttr>()) {
-    isCS = true;
-    funcProps->shaderKind = DXIL::ShaderKind::Compute;
+    if (isMS) {
+      funcProps->ShaderProps.MS.numThreads[0] = Attr->getX();
+      funcProps->ShaderProps.MS.numThreads[1] = Attr->getY();
+      funcProps->ShaderProps.MS.numThreads[2] = Attr->getZ();
+    } else if (isAS) {
+      funcProps->ShaderProps.AS.numThreads[0] = Attr->getX();
+      funcProps->ShaderProps.AS.numThreads[1] = Attr->getY();
+      funcProps->ShaderProps.AS.numThreads[2] = Attr->getZ();
+    } else {
+      isCS = true;
+      funcProps->shaderKind = DXIL::ShaderKind::Compute;
 
-    funcProps->ShaderProps.CS.numThreads[0] = Attr->getX();
-    funcProps->ShaderProps.CS.numThreads[1] = Attr->getY();
-    funcProps->ShaderProps.CS.numThreads[2] = Attr->getZ();
+      funcProps->ShaderProps.CS.numThreads[0] = Attr->getX();
+      funcProps->ShaderProps.CS.numThreads[1] = Attr->getY();
+      funcProps->ShaderProps.CS.numThreads[2] = Attr->getZ();
+    }
 
-    if (isEntry && !SM->IsCS()) {
+    if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
       unsigned DiagID = Diags.getCustomDiagID(
-          DiagnosticsEngine::Error, "attribute numthreads only valid for CS.");
+          DiagnosticsEngine::Error, "attribute numthreads only valid for CS/MS/AS.");
       Diags.Report(Attr->getLocation(), DiagID);
       return;
     }
@@ -1398,10 +1445,16 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
       DXIL::TessellatorOutputPrimitive primitive =
           StringToTessOutputPrimitive(Attr->getTopology());
       funcProps->ShaderProps.HS.outputPrimitive = primitive;
-    } else if (isEntry && !SM->IsHS()) {
+    }
+    else if (isMS) {
+      DXIL::MeshOutputTopology topology =
+          StringToMeshOutputTopology(Attr->getTopology());
+      funcProps->ShaderProps.MS.outputTopology = topology;
+    }
+    else if (isEntry && !SM->IsHS() && !SM->IsMS()) {
       unsigned DiagID =
           Diags.getCustomDiagID(DiagnosticsEngine::Warning,
-                                "attribute outputtopology only valid for HS.");
+                                "attribute outputtopology only valid for HS and MS.");
       Diags.Report(Attr->getLocation(), DiagID);
     }
   }
@@ -1478,7 +1531,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     funcProps->shaderKind = DXIL::ShaderKind::Pixel;
   }
 
-  const unsigned profileAttributes = isCS + isHS + isDS + isGS + isVS + isPS + isRay;
+  const unsigned profileAttributes = isCS + isHS + isDS + isGS + isVS + isPS + isRay + isMS + isAS;
 
   // TODO: check this in front-end and report error.
   DXASSERT(profileAttributes < 2, "profile attributes are mutual exclusive");
@@ -1491,6 +1544,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     case ShaderModel::Kind::Geometry:
     case ShaderModel::Kind::Vertex:
     case ShaderModel::Kind::Pixel:
+    case ShaderModel::Kind::Mesh:
+    case ShaderModel::Kind::Amplification:
       DXASSERT(funcProps->shaderKind == SM->GetKind(),
                "attribute profile not match entry function profile");
       break;
@@ -1561,6 +1616,10 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     funcProps->ShaderProps.Ray.attributeSizeInBytes = 0;
   }
 
+  bool hasOutIndices = false;
+  bool hasOutVertices = false;
+  bool hasOutPrimitives = false;
+  bool hasInPayload = false;
   for (; ArgNo < F->arg_size(); ++ArgNo, ++ParmIdx) {
     DxilParameterAnnotation &paramAnnotation =
         FuncAnnotation->GetParameterAnnotation(ArgNo);
@@ -1591,6 +1650,155 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     if (parmDecl->hasAttr<HLSLOutAttr>() && parmDecl->hasAttr<HLSLInAttr>())
       dxilInputQ = DxilParamInputQual::Inout;
 
+    if (parmDecl->hasAttr<HLSLOutAttr>() && parmDecl->hasAttr<HLSLIndicesAttr>()) {
+      if (hasOutIndices) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "multiple out indices parameters not allowed");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(fieldTy.getCanonicalType());
+      if (CAT == nullptr) {
+        unsigned DiagID = Diags.getCustomDiagID(
+          DiagnosticsEngine::Error,
+          "indices output is not an constant-length array");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      unsigned count = CAT->getSize().getZExtValue();
+      if (count > DXIL::kMaxMSOutputPrimitiveCount) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "max primitive count should not exceed %0");
+        Diags.Report(parmDecl->getLocation(), DiagID) << DXIL::kMaxMSOutputPrimitiveCount;
+        continue;
+      }
+      if (funcProps->ShaderProps.MS.maxPrimitiveCount != 0 &&
+        funcProps->ShaderProps.MS.maxPrimitiveCount != count) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "max primitive count mismatch");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      // Get element type.
+      QualType arrayEleTy = CAT->getElementType();
+
+      if (hlsl::IsHLSLVecType(arrayEleTy)) {
+        QualType vecEltTy = hlsl::GetHLSLVecElementType(arrayEleTy);
+        if (!vecEltTy->isUnsignedIntegerType() || CGM.getContext().getTypeSize(vecEltTy) != 32) {
+          unsigned DiagID = Diags.getCustomDiagID(
+              DiagnosticsEngine::Error,
+              "the element of out_indices array must be uint2 for line output or uint3 for triangle output");
+          Diags.Report(parmDecl->getLocation(), DiagID);
+          continue;
+        }
+        unsigned vecEltCount = hlsl::GetHLSLVecSize(arrayEleTy);
+        if (funcProps->ShaderProps.MS.outputTopology == DXIL::MeshOutputTopology::Line && vecEltCount != 2) {
+          unsigned DiagID = Diags.getCustomDiagID(
+              DiagnosticsEngine::Error,
+              "the element of out_indices array in a mesh shader whose output topology is line must be uint2");
+          Diags.Report(parmDecl->getLocation(), DiagID);
+          continue;
+        }
+        if (funcProps->ShaderProps.MS.outputTopology == DXIL::MeshOutputTopology::Triangle && vecEltCount != 3) {
+          unsigned DiagID = Diags.getCustomDiagID(
+              DiagnosticsEngine::Error,
+              "the element of out_indices array in a mesh shader whose output topology is triangle must be uint3");
+          Diags.Report(parmDecl->getLocation(), DiagID);
+          continue;
+        }
+      } else {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "the element of out_indices array must be uint2 for line output or uint3 for triangle output");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+
+      dxilInputQ = DxilParamInputQual::OutIndices;
+      funcProps->ShaderProps.MS.maxPrimitiveCount = count;
+      hasOutIndices = true;
+    }
+    if (parmDecl->hasAttr<HLSLOutAttr>() && parmDecl->hasAttr<HLSLVerticesAttr>()) {
+      if (hasOutVertices) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "multiple out vertices parameters not allowed");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(fieldTy.getCanonicalType());
+      if (CAT == nullptr) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "vertices output is not an constant-length array");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      unsigned count = CAT->getSize().getZExtValue();
+      if (count > DXIL::kMaxMSOutputVertexCount) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "max vertex count should not exceed %0");
+        Diags.Report(parmDecl->getLocation(), DiagID) << DXIL::kMaxMSOutputVertexCount;
+        continue;
+      }
+
+      dxilInputQ = DxilParamInputQual::OutVertices;
+      funcProps->ShaderProps.MS.maxVertexCount = count;
+      hasOutVertices = true;
+    }
+    if (parmDecl->hasAttr<HLSLOutAttr>() && parmDecl->hasAttr<HLSLPrimitivesAttr>()) {
+      if (hasOutPrimitives) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "multiple out primitives parameters not allowed");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(fieldTy.getCanonicalType());
+      if (CAT == nullptr) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "primitives output is not an constant-length array");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      unsigned count = CAT->getSize().getZExtValue();
+      if (count > DXIL::kMaxMSOutputPrimitiveCount) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "max primitive count should not exceed %0");
+        Diags.Report(parmDecl->getLocation(), DiagID) << DXIL::kMaxMSOutputPrimitiveCount;
+        continue;
+      }
+      if (funcProps->ShaderProps.MS.maxPrimitiveCount != 0 &&
+        funcProps->ShaderProps.MS.maxPrimitiveCount != count) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "max primitive count mismatch");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+
+      dxilInputQ = DxilParamInputQual::OutPrimitives;
+      funcProps->ShaderProps.MS.maxPrimitiveCount = count;
+      hasOutPrimitives = true;
+    }
+    if (parmDecl->hasAttr<HLSLInAttr>() && parmDecl->hasAttr<HLSLPayloadAttr>()) {
+      if (hasInPayload) {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error,
+            "multiple in payload parameters not allowed");
+        Diags.Report(parmDecl->getLocation(), DiagID);
+        continue;
+      }
+      dxilInputQ = DxilParamInputQual::InPayload;
+      hasInPayload = true;
+    }
+
     DXIL::InputPrimitive inputPrimitive = DXIL::InputPrimitive::Undefined;
 
     if (IsHLSLOutputPatchType(parmDecl->getType())) {

+ 21 - 0
tools/clang/lib/Parse/ParseDecl.cpp

@@ -763,6 +763,10 @@ void Parser::ParseGNUAttributeArgs(IdentifierInfo *AttrName,
     //case AttributeList::AT_HLSLLineAdj:
     //case AttributeList::AT_HLSLTriangle:
     //case AttributeList::AT_HLSLTriangleAdj:
+    //case AttributeList::AT_HLSLIndices:
+    //case AttributeList::AT_HLSLVertices:
+    //case AttributeList::AT_HLSLPrimitives:
+    //case AttributeList::AT_HLSLPayload:
       goto GenericAttributeParse;
     default:
       Diag(AttrNameLoc, diag::err_hlsl_unsupported_construct) << AttrName;
@@ -3786,9 +3790,15 @@ HLSLReservedKeyword:
     case tok::kw_sample:
     case tok::kw_globallycoherent:
     case tok::kw_center:
+    case tok::kw_indices:
+    case tok::kw_vertices:
+    case tok::kw_primitives:
+    case tok::kw_payload:
       // Back-compat: 'precise', 'globallycoherent', 'center' and 'sample' are keywords when used as an interpolation
       // modifiers, but in FXC they can also be used an identifiers. If the decl type has already been specified
       // we need to update the token to be handled as an identifier.
+      // Similarly 'indices', 'vertices', 'primitives' and 'payload' are keywords
+      // when used as a type qualifer in mesh shader, but may still be used as a variable name.
       if (getLangOpts().HLSL) {
         if (DS.getTypeSpecType() != DeclSpec::TST_unspecified) {
           Tok.setKind(tok::identifier);
@@ -5236,6 +5246,10 @@ bool Parser::isDeclarationSpecifier(bool DisambiguatingWithExpression) {
   case tok::kw_triangle:
   case tok::kw_triangleadj:
   case tok::kw_export:
+  case tok::kw_indices:
+  case tok::kw_vertices:
+  case tok::kw_primitives:
+  case tok::kw_payload:
     return true;
   // HLSL Change Ends
 
@@ -6006,6 +6020,9 @@ void Parser::ParseDirectDeclarator(Declarator &D) {
     // FXC they can also be used an identifiers. If the next token is a
     // punctuator, then we are using them as identifers. Need to change
     // the token type to tok::identifier and fall through to the next case.
+    // Similarly 'indices', 'vertices', 'primitives' and 'payload' are keywords
+    // when used as a type qualifer in mesh shader, but may still be used as a
+    // variable name.
     // E.g., <type> left, center, right;
     if (getLangOpts().HLSL) {
       switch (Tok.getKind()) {
@@ -6013,6 +6030,10 @@ void Parser::ParseDirectDeclarator(Declarator &D) {
       case tok::kw_globallycoherent:
       case tok::kw_precise:
       case tok::kw_sample:
+      case tok::kw_indices:
+      case tok::kw_vertices:
+      case tok::kw_primitives:
+      case tok::kw_payload:
         if (tok::isPunctuator(NextToken().getKind()))
           Tok.setKind(tok::identifier);
         break;

+ 10 - 0
tools/clang/lib/Parse/ParseExpr.cpp

@@ -794,9 +794,15 @@ HLSLReservedKeyword:
   case tok::kw_sample:
   case tok::kw_globallycoherent:
   case tok::kw_center:
+  case tok::kw_indices:
+  case tok::kw_vertices:
+  case tok::kw_primitives:
+  case tok::kw_payload:
     // Back-compat: 'precise', 'globallycoherent', 'center' and 'sample' are keywords when used as an interpolation 
     // modifiers, but in FXC they can also be used an identifiers. No interpolation modifiers are expected here
     // so we need to change the token type to tok::identifier and fall through to the next case.
+    // Similarly 'indices', 'vertices', 'primitives' and 'payload' are keywords when used
+    // as a type qualifer in mesh shader, but may still be used as a variable name.
     Tok.setKind(tok::identifier);
     __fallthrough;
     // HLSL Change Ends
@@ -1722,6 +1728,10 @@ Parser::ParsePostfixExpressionSuffix(ExprResult LHS) {
         case tok::kw_globallycoherent:
         case tok::kw_precise:
         case tok::kw_sample:
+        case tok::kw_indices:
+        case tok::kw_vertices:
+        case tok::kw_primitives:
+        case tok::kw_payload:
           Tok.setKind(tok::identifier);
           Tok.setIdentifierInfo(PP.getIdentifierInfo(getKeywordSpelling(tk)));
           break;

+ 5 - 1
tools/clang/lib/Parse/ParseStmt.cpp

@@ -179,7 +179,11 @@ Retry:
   case tok::kw_precise:
   case tok::kw_sample:
   case tok::kw_globallycoherent:
-  case tok::kw_center: {
+  case tok::kw_center:
+  case tok::kw_indices:
+  case tok::kw_vertices:
+  case tok::kw_primitives:
+  case tok::kw_payload: {
     // FXC compatiblity: these are keywords when used as modifiers, but in
     // FXC they can also be used an identifiers. If the next token is a
     // punctuator, then we are using them as identifers. Need to change

+ 4 - 0
tools/clang/lib/Parse/ParseTentative.cpp

@@ -1275,6 +1275,10 @@ Parser::isCXXDeclarationSpecifier(Parser::TPResult BracedCastResult,
   case tok::kw_precise:
   case tok::kw_center:
   case tok::kw_globallycoherent:
+  case tok::kw_indices:
+  case tok::kw_vertices:
+  case tok::kw_primitives:
+  case tok::kw_payload:
     // FXC compatiblity: these are keywords when used as modifiers, but in
     // FXC they can also be used an identifiers. If the next token is a
     // punctuator, then we are using them as identifers. Need to change

+ 14 - 6
tools/clang/lib/SPIRV/CapabilityVisitor.cpp

@@ -273,8 +273,9 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
       break;
     }
     case spv::BuiltIn::PrimitiveId: {
-      // PrimitiveID can be used as PSIn
-      if (shaderModel == spv::ExecutionModel::Fragment)
+      // PrimitiveID can be used as PSIn or MSPOut.
+      if (shaderModel == spv::ExecutionModel::Fragment ||
+          shaderModel == spv::ExecutionModel::MeshNV)
         addCapability(spv::Capability::Geometry);
       break;
     }
@@ -285,8 +286,9 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
         addExtension(Extension::EXT_shader_viewport_index_layer,
                      "SV_RenderTargetArrayIndex", loc);
         addCapability(spv::Capability::ShaderViewportIndexLayerEXT);
-      } else if (shaderModel == spv::ExecutionModel::Fragment) {
-        // SV_RenderTargetArrayIndex can be used as PSIn.
+      } else if (shaderModel == spv::ExecutionModel::Fragment ||
+                 shaderModel == spv::ExecutionModel::MeshNV) {
+        // SV_RenderTargetArrayIndex can be used as PSIn or MSPOut.
         addCapability(spv::Capability::Geometry);
       }
       break;
@@ -299,8 +301,9 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
                      "SV_ViewPortArrayIndex", loc);
         addCapability(spv::Capability::ShaderViewportIndexLayerEXT);
       } else if (shaderModel == spv::ExecutionModel::Fragment ||
-                 shaderModel == spv::ExecutionModel::Geometry) {
-        // SV_ViewportArrayIndex can be used as PSIn.
+                 shaderModel == spv::ExecutionModel::Geometry ||
+                 shaderModel == spv::ExecutionModel::MeshNV) {
+        // SV_ViewportArrayIndex can be used as PSIn or GSOut or MSPOut.
         addCapability(spv::Capability::MultiViewport);
       }
       break;
@@ -503,6 +506,11 @@ bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
     addCapability(spv::Capability::RayTracingNV);
     addExtension(Extension::NV_ray_tracing, "SPV_NV_ray_tracing", {});
     break;
+  case spv::ExecutionModel::MeshNV:
+  case spv::ExecutionModel::TaskNV:
+    addCapability(spv::Capability::MeshShadingNV);
+    addExtension(Extension::NV_mesh_shader, "SPV_NV_mesh_shader", {});
+    break;
   default:
     llvm_unreachable("found unknown shader model");
     break;

+ 243 - 49
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -21,6 +21,7 @@
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Casting.h"
 
+#include "AlignmentSizeCalculator.h"
 #include "SpirvEmitter.h"
 
 namespace clang {
@@ -295,6 +296,15 @@ hlsl::DxilParamInputQual deduceParamQual(const DeclaratorDecl *decl,
   if (hasGSPrimitiveTypeQualifier(decl))
     return hlsl::DxilParamInputQual::InputPrimitive;
 
+  if (decl->hasAttr<HLSLIndicesAttr>())
+    return hlsl::DxilParamInputQual::OutIndices;
+  if (decl->hasAttr<HLSLVerticesAttr>())
+    return hlsl::DxilParamInputQual::OutVertices;
+  if (decl->hasAttr<HLSLPrimitivesAttr>())
+    return hlsl::DxilParamInputQual::OutPrimitives;
+  if (decl->hasAttr<HLSLPayloadAttr>())
+    return hlsl::DxilParamInputQual::InPayload;
+
   return asInput ? hlsl::DxilParamInputQual::In : hlsl::DxilParamInputQual::Out;
 }
 
@@ -432,12 +442,35 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
                                               SpirvInstruction *storedValue,
                                               bool forPCF) {
   QualType type = getTypeOrFnRetType(decl);
+  uint32_t arraySize = 0;
 
   // Output stream types (PointStream, LineStream, TriangleStream) are
   // translated as their underlying struct types.
   if (hlsl::IsHLSLStreamOutputType(type))
     type = hlsl::GetHLSLResourceResultType(type);
 
+  if (decl->hasAttr<HLSLIndicesAttr>() || decl->hasAttr<HLSLVerticesAttr>() ||
+      decl->hasAttr<HLSLPrimitivesAttr>()) {
+    const auto *typeDecl = astContext.getAsConstantArrayType(type);
+    type = typeDecl->getElementType();
+    arraySize = static_cast<uint32_t>(typeDecl->getSize().getZExtValue());
+    if (decl->hasAttr<HLSLIndicesAttr>()) {
+      // create SPIR-V builtin array PrimitiveIndicesNV of type
+      // "uint [MaxPrimitiveCount * verticesPerPrim]"
+      uint32_t verticesPerPrim = 1;
+      if (!isVectorType(type, nullptr, &verticesPerPrim)) {
+        assert(isScalarType(type));
+      }
+      arraySize = arraySize * verticesPerPrim;
+      QualType arrayType = astContext.getConstantArrayType(
+          astContext.UnsignedIntTy, llvm::APInt(32, arraySize),
+          clang::ArrayType::Normal, 0);
+      stageVarInstructions[cast<DeclaratorDecl>(decl)] = getBuiltinVar(
+          spv::BuiltIn::PrimitiveIndicesNV, arrayType, decl->getLocation());
+      return true;
+    }
+  }
+
   const auto *sigPoint = deduceSigPoint(
       decl, /*asInput=*/false, spvContext.getCurrentShaderModelKind(), forPCF);
 
@@ -453,11 +486,12 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
   // Write back of stage output variables in GS is manually controlled by
   // .Append() intrinsic method, implemented in writeBackOutputStream(). So
   // ignoreValue should be set to true for GS.
-  const bool noWriteBack = storedValue == nullptr || spvContext.isGS();
+  const bool noWriteBack =
+      storedValue == nullptr || spvContext.isGS() || spvContext.isMS();
 
-  return createStageVars(sigPoint, decl, /*asInput=*/false, type,
-                         /*arraySize=*/0, "out.var", llvm::None, &storedValue,
-                         noWriteBack, &inheritSemantic);
+  return createStageVars(sigPoint, decl, /*asInput=*/false, type, arraySize,
+                         "out.var", llvm::None, &storedValue, noWriteBack,
+                         &inheritSemantic);
 }
 
 bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
@@ -505,9 +539,15 @@ bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
 
   SemanticInfo inheritSemantic = {};
 
-  return createStageVars(sigPoint, paramDecl, /*asInput=*/true, type, arraySize,
-                         "in.var", llvm::None, loadedValue,
-                         /*noWriteBack=*/false, &inheritSemantic);
+  if (paramDecl->hasAttr<HLSLPayloadAttr>()) {
+    spv::StorageClass sc = getStorageClassForSigPoint(sigPoint);
+    return createPayloadStageVars(sigPoint, sc, paramDecl, /*asInput=*/true,
+                                  type, "in.var", loadedValue);
+  } else {
+    return createStageVars(sigPoint, paramDecl, /*asInput=*/true, type,
+                           arraySize, "in.var", llvm::None, loadedValue,
+                           /*noWriteBack=*/false, &inheritSemantic);
+  }
 }
 
 const DeclResultIdMapper::DeclSpirvInfo *
@@ -1193,9 +1233,10 @@ bool DeclResultIdMapper::checkSemanticDuplication(bool forInput) {
     auto s = var.getSemanticStr();
 
     if (s.empty()) {
-      // We translate WaveGetLaneCount() and WaveGetLaneIndex() into builtin
-      // variables. Those variables are inserted into the normal stage IO
-      // processing pipeline, but with the semantics as empty strings.
+      // We translate WaveGetLaneCount(), WaveGetLaneIndex() and 'payload' param
+      // block declaration into builtin variables. Those variables are inserted
+      // into the normal stage IO processing pipeline, but with the semantics as
+      // empty strings.
       assert(var.isSpirvBuitin());
       continue;
     }
@@ -1671,17 +1712,31 @@ bool DeclResultIdMapper::createStageVars(
       return false;
 
     const auto semanticKind = semanticToUse->getKind();
+    const auto sigPointKind = sigPoint->GetKind();
 
     // Error out when the given semantic is invalid in this shader model
-    if (hlsl::SigPoint::GetInterpretation(semanticKind, sigPoint->GetKind(),
+    if (hlsl::SigPoint::GetInterpretation(semanticKind, sigPointKind,
                                           spvContext.getMajorVersion(),
                                           spvContext.getMinorVersion()) ==
         hlsl::DXIL::SemanticInterpretationKind::NA) {
-      emitError("invalid usage of semantic '%0' in shader profile %1", loc)
-          << semanticToUse->str
-          << hlsl::ShaderModel::GetKindName(
-                 spvContext.getCurrentShaderModelKind());
-      return false;
+      // Special handle MSIn/ASIn allowing VK-only builtin "DrawIndex".
+      switch (sigPointKind) {
+      case hlsl::SigPoint::Kind::MSIn:
+      case hlsl::SigPoint::Kind::ASIn:
+        if (const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>()) {
+          const llvm::StringRef builtin = builtinAttr->getBuiltIn();
+          if (builtin == "DrawIndex") {
+            break;
+          }
+        }
+        // fall through
+      default:
+        emitError("invalid usage of semantic '%0' in shader profile %1", loc)
+            << semanticToUse->str
+            << hlsl::ShaderModel::GetKindName(
+                   spvContext.getCurrentShaderModelKind());
+        return false;
+      }
     }
 
     if (!validateVKBuiltins(decl, sigPoint))
@@ -1711,9 +1766,9 @@ bool DeclResultIdMapper::createStageVars(
     // * SV_ShadingRate is a uint value, but the builtin it corresponds to is a
     //   int2.
 
-    if (glPerVertex.tryToAccess(sigPoint->GetKind(), semanticKind,
+    if (glPerVertex.tryToAccess(sigPointKind, semanticKind,
                                 semanticToUse->index, invocationId, value,
-                                noWriteBack, loc))
+                                noWriteBack, /*vecComponent=*/nullptr, loc))
       return true;
 
     switch (semanticKind) {
@@ -1757,7 +1812,7 @@ bool DeclResultIdMapper::createStageVars(
 
     // Boolean stage I/O variables must be represented as unsigned integers.
     // Boolean built-in variables are represented as bool.
-    if (isBooleanStageIOVar(decl, type, semanticKind, sigPoint->GetKind())) {
+    if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) {
       evalType = getUintTypeWithSourceComponents(astContext, type);
     }
 
@@ -1796,8 +1851,15 @@ bool DeclResultIdMapper::createStageVars(
 
     // TODO: the following may not be correct?
     if (sigPoint->GetSignatureKind() ==
-        hlsl::DXIL::SignatureKind::PatchConstant)
-      spvBuilder.decoratePatch(varInstr, varInstr->getSourceLocation());
+        hlsl::DXIL::SignatureKind::PatchConstOrPrim) {
+      if (sigPointKind == hlsl::SigPoint::Kind::MSPOut) {
+        // Decorate with PerPrimitiveNV for per-primitive out variables.
+        spvBuilder.decoratePerPrimitiveNV(varInstr,
+                                          varInstr->getSourceLocation());
+      } else {
+        spvBuilder.decoratePatch(varInstr, varInstr->getSourceLocation());
+      }
+    }
 
     // Decorate with interpolation modes for pixel shader input variables
     if (spvContext.isPS() && sigPoint->IsInput() &&
@@ -1947,7 +2009,7 @@ bool DeclResultIdMapper::createStageVars(
 
       // Since boolean stage input variables are represented as unsigned
       // integers, after loading them, we should cast them to boolean.
-      if (isBooleanStageIOVar(decl, type, semanticKind, sigPoint->GetKind())) {
+      if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) {
         *value =
             theEmitter.castToType(*value, evalType, type, thisSemantic.loc);
       }
@@ -2026,8 +2088,7 @@ bool DeclResultIdMapper::createStageVars(
       }
       // Since boolean output stage variables are represented as unsigned
       // integers, we must cast the value to uint before storing.
-      else if (isBooleanStageIOVar(decl, type, semanticKind,
-                                   sigPoint->GetKind())) {
+      else if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) {
         *value =
             theEmitter.castToType(*value, type, evalType, thisSemantic.loc);
         spvBuilder.createStore(ptr, *value, thisSemantic.loc);
@@ -2153,6 +2214,7 @@ bool DeclResultIdMapper::createStageVars(
     //       but we only write to the struct at the InvocationID index
     // * DS: output is a single struct, without extra arrayness
     // * GS: output is controlled by OpEmitVertex, one vertex per time
+    // * MS: output is an array of structs, with extra arrayness
     //
     // The interesting shader stage is HS. We need the InvocationID to write
     // out the value to the correct array element.
@@ -2174,6 +2236,96 @@ bool DeclResultIdMapper::createStageVars(
   return true;
 }
 
+bool DeclResultIdMapper::createPayloadStageVars(
+    const hlsl::SigPoint *sigPoint, spv::StorageClass sc, const NamedDecl *decl,
+    bool asInput, QualType type, const llvm::StringRef namePrefix,
+    SpirvInstruction **value, uint32_t payloadMemOffset) {
+  assert(spvContext.isMS() || spvContext.isAS());
+  assert(value);
+
+  if (type->isVoidType()) {
+    // No stage variables will be created for void type.
+    return true;
+  }
+
+  const auto loc = decl->getLocation();
+  if (!type->isStructureType()) {
+    StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
+                      getLocationCount(astContext, type));
+    const auto name = namePrefix.str() + "." + decl->getNameAsString();
+    SpirvVariable *varInstr =
+        spvBuilder.addStageIOVar(type, sc, name, /*isPrecise=*/false);
+
+    if (!varInstr)
+      return false;
+
+    // Even though these as user defined IO stage variables, set them as SPIR-V
+    // builtins in order to bypass any semantic string checks and location
+    // assignment.
+    stageVar.setIsSpirvBuiltin();
+    stageVar.setSpirvInstr(varInstr);
+    stageVars.push_back(stageVar);
+
+    // Decorate with PerTaskNV for mesh/amplification shader payload variables.
+    spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset,
+                                 varInstr->getSourceLocation());
+
+    if (asInput) {
+      *value = spvBuilder.createLoad(type, varInstr, loc);
+    } else {
+      spvBuilder.createStore(varInstr, *value, loc);
+    }
+    return true;
+  }
+
+  // This decl translates into multiple stage input/output payload variables
+  // and we need to load/store these individual member variables.
+  const auto *structDecl = type->getAs<RecordType>()->getDecl();
+  llvm::SmallVector<SpirvInstruction *, 4> subValues;
+  AlignmentSizeCalculator alignmentCalc(astContext, spirvOptions);
+  uint32_t nextMemberOffset = 0;
+
+  for (const auto *field : structDecl->fields()) {
+    const auto fieldType = field->getType();
+    SpirvInstruction *subValue = nullptr;
+    uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
+
+    // The next avaiable offset after laying out the previous members.
+    std::tie(memberAlignment, memberSize) = alignmentCalc.getAlignmentAndSize(
+        field->getType(), spirvOptions.ampPayloadLayoutRule,
+        /*isRowMajor*/ llvm::None, &stride);
+    alignmentCalc.alignUsingHLSLRelaxedLayout(
+        field->getType(), memberSize, memberAlignment, &nextMemberOffset);
+
+    // The vk::offset attribute takes precedence over all.
+    if (field->getAttr<VKOffsetAttr>()) {
+      nextMemberOffset = field->getAttr<VKOffsetAttr>()->getOffset();
+    }
+
+    // Each payload member must have an Offset Decoration.
+    payloadMemOffset = nextMemberOffset;
+    nextMemberOffset += memberSize;
+
+    if (!asInput) {
+      subValue = spvBuilder.createCompositeExtract(
+          fieldType, *value, {getNumBaseClasses(type) + field->getFieldIndex()},
+          loc);
+    }
+
+    if (!createPayloadStageVars(sigPoint, sc, field, asInput, field->getType(),
+                                namePrefix, &subValue, payloadMemOffset))
+      return false;
+
+    if (asInput) {
+      subValues.push_back(subValue);
+    }
+  }
+  if (asInput) {
+    *value = spvBuilder.createCompositeConstruct(type, subValues, loc);
+  }
+  return true;
+}
+
 bool DeclResultIdMapper::writeBackOutputStream(const NamedDecl *decl,
                                                QualType type,
                                                SpirvInstruction *value) {
@@ -2191,11 +2343,11 @@ bool DeclResultIdMapper::writeBackOutputStream(const NamedDecl *decl,
     // Found semantic attached directly to this Decl. Write the value for this
     // Decl to the corresponding stage output variable.
 
-    // Handle SV_Position, SV_ClipDistance, and SV_CullDistance
-    if (glPerVertex.tryToAccess(hlsl::DXIL::SigPointKind::GSOut,
-                                semanticInfo.semantic->GetKind(),
-                                semanticInfo.index, llvm::None, &value,
-                                /*noWriteBack=*/false, loc))
+    // Handle SV_ClipDistance, and SV_CullDistance
+    if (glPerVertex.tryToAccess(
+            hlsl::DXIL::SigPointKind::GSOut, semanticInfo.semantic->GetKind(),
+            semanticInfo.index, llvm::None, &value,
+            /*noWriteBack=*/false, /*vecComponent=*/nullptr, loc))
       return true;
 
     // Query the <result-id> for the stage output variable generated out
@@ -2335,6 +2487,8 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
   if (builtInVar != builtinToVarMap.end()) {
     return builtInVar->second;
   }
+  spv::StorageClass sc = spv::StorageClass::Max;
+  // Valid builtins supported
   switch (builtIn) {
   case spv::BuiltIn::SubgroupSize:
   case spv::BuiltIn::SubgroupLocalInvocationId:
@@ -2353,7 +2507,12 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
   case spv::BuiltIn::WorldToObjectNV:
   case spv::BuiltIn::LaunchIdNV:
   case spv::BuiltIn::LaunchSizeNV:
-    // Valid builtins supported
+    sc = spv::StorageClass::Input;
+    break;
+  case spv::BuiltIn::PrimitiveCountNV:
+  case spv::BuiltIn::PrimitiveIndicesNV:
+  case spv::BuiltIn::TaskCountNV:
+    sc = spv::StorageClass::Output;
     break;
   default:
     assert(false && "unsupported SPIR-V builtin");
@@ -2361,8 +2520,8 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
   }
 
   // Create a dummy StageVar for this builtin variable
-  auto var = spvBuilder.addStageBuiltinVar(type, spv::StorageClass::Input,
-                                           builtIn, /*isPrecise*/ false, loc);
+  auto var = spvBuilder.addStageBuiltinVar(type, sc, builtIn,
+                                           /*isPrecise*/ false, loc);
 
   const hlsl::SigPoint *sigPoint =
       hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
@@ -2407,6 +2566,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
             .Case("BaseInstance", BuiltIn::BaseInstance)
             .Case("DrawIndex", BuiltIn::DrawIndex)
             .Case("DeviceIndex", BuiltIn::DeviceIndex)
+            .Case("ViewportMaskNV", BuiltIn::ViewportMaskNV)
             .Default(BuiltIn::Max);
 
     assert(spvBuiltIn != BuiltIn::Max); // The frontend should guarantee this.
@@ -2419,9 +2579,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   // each semantic.
   switch (semanticKind) {
   // According to DXIL spec, the Position SV can be used by all SigPoints
-  // other than PCIn, HSIn, GSIn, PSOut, CSIn.
+  // other than PCIn, HSIn, GSIn, PSOut, CSIn, MSIn, MSPOut, ASIn.
   // According to Vulkan spec, the Position BuiltIn can only be used
-  // by VSOut, HS/DS/GS In/Out.
+  // by VSOut, HS/DS/GS In/Out, MSOut.
   case hlsl::Semantic::Kind::Position: {
     switch (sigPointKind) {
     case hlsl::SigPoint::Kind::VSIn:
@@ -2435,6 +2595,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     case hlsl::SigPoint::Kind::DSOut:
     case hlsl::SigPoint::Kind::GSVIn:
     case hlsl::SigPoint::Kind::GSOut:
+    case hlsl::SigPoint::Kind::MSOut:
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Position,
                                            isPrecise, srcLoc);
@@ -2497,9 +2658,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
                                          isPrecise, srcLoc);
   }
   // According to DXIL spec, the ClipDistance/CullDistance SV can be used by all
-  // SigPoints other than PCIn, HSIn, GSIn, PSOut, CSIn.
-  // According to Vulkan spec, the ClipDistance/CullDistance BuiltIn can only be
-  // used by VSOut, HS/DS/GS In/Out.
+  // SigPoints other than PCIn, HSIn, GSIn, PSOut, CSIn, MSIn, MSPOut, ASIn.
+  // According to Vulkan spec, the ClipDistance/CullDistance
+  // BuiltIn can only be used by VSOut, HS/DS/GS In/Out, MSOut.
   case hlsl::Semantic::Kind::ClipDistance:
   case hlsl::Semantic::Kind::CullDistance: {
     switch (sigPointKind) {
@@ -2515,6 +2676,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     case hlsl::SigPoint::Kind::GSVIn:
     case hlsl::SigPoint::Kind::GSOut:
     case hlsl::SigPoint::Kind::PSIn:
+    case hlsl::SigPoint::Kind::MSOut:
       llvm_unreachable("should be handled in gl_PerVertex struct");
     default:
       llvm_unreachable(
@@ -2584,9 +2746,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
                                          isPrecise, srcLoc);
   }
   // According to DXIL spec, the PrimitiveID SV can only be used by PCIn, HSIn,
-  // DSIn, GSIn, GSOut, and PSIn.
+  // DSIn, GSIn, GSOut, PSIn, and MSPOut.
   // According to Vulkan spec, the PrimitiveId BuiltIn can only be used in
-  // HS/DS/PS In, GS In/Out.
+  // HS/DS/PS In, GS In/Out, MSPOut.
   case hlsl::Semantic::Kind::PrimitiveID: {
     // Translate to PrimitiveId BuiltIn for all valid SigPoints.
     stageVar->setIsSpirvBuiltin();
@@ -2666,9 +2828,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     return spvBuilder.addStageBuiltinVar(type, sc, bi, isPrecise, srcLoc);
   }
   // According to DXIL spec, the RenderTargetArrayIndex SV can only be used by
-  // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn.
-  // According to Vulkan spec, the Layer BuiltIn can only be used in GSOut and
-  // PSIn.
+  // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn, MSPOut.
+  // According to Vulkan spec, the Layer BuiltIn can only be used in GSOut
+  // PSIn, and MSPOut.
   case hlsl::Semantic::Kind::RenderTargetArrayIndex: {
     switch (sigPointKind) {
     case hlsl::SigPoint::Kind::VSIn:
@@ -2686,6 +2848,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
                                            srcLoc);
     case hlsl::SigPoint::Kind::GSOut:
     case hlsl::SigPoint::Kind::PSIn:
+    case hlsl::SigPoint::Kind::MSPOut:
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer, isPrecise,
                                            srcLoc);
@@ -2694,9 +2857,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     }
   }
   // According to DXIL spec, the ViewportArrayIndex SV can only be used by
-  // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn.
+  // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn, MSPOut.
   // According to Vulkan spec, the ViewportIndex BuiltIn can only be used in
-  // GSOut and PSIn.
+  // GSOut, PSIn, and MSPOut.
   case hlsl::Semantic::Kind::ViewPortArrayIndex: {
     switch (sigPointKind) {
     case hlsl::SigPoint::Kind::VSIn:
@@ -2714,6 +2877,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
                                            isPrecise, srcLoc);
     case hlsl::SigPoint::Kind::GSOut:
     case hlsl::SigPoint::Kind::PSIn:
+    case hlsl::SigPoint::Kind::MSPOut:
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex,
                                            isPrecise, srcLoc);
@@ -2852,6 +3016,7 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl,
       case hlsl::SigPoint::Kind::GSVIn:
       case hlsl::SigPoint::Kind::GSOut:
       case hlsl::SigPoint::Kind::PSIn:
+      case hlsl::SigPoint::Kind::MSOut:
         break;
       default:
         emitError("PointSize builtin cannot be used as %0", loc)
@@ -2867,9 +3032,20 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl,
         success = false;
       }
 
-      if (sigPoint->GetKind() != hlsl::SigPoint::Kind::VSIn) {
-        emitError("%0 builtin can only be used in vertex shader input", loc)
-            << builtin;
+      switch (sigPoint->GetKind()) {
+      case hlsl::SigPoint::Kind::VSIn:
+        break;
+      case hlsl::SigPoint::Kind::MSIn:
+      case hlsl::SigPoint::Kind::ASIn:
+        if (builtin != "DrawIndex") {
+          emitError("%0 builtin cannot be used as %1", loc)
+              << builtin << sigPoint->GetName();
+          success = false;
+        }
+        break;
+      default:
+        emitError("%0 builtin cannot be used as %1", loc)
+            << builtin << sigPoint->GetName();
         success = false;
       }
     } else if (builtin == "DeviceIndex") {
@@ -2884,6 +3060,20 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl,
             << builtin;
         success = false;
       }
+    } else if (builtin == "ViewportMaskNV") {
+      if (sigPoint->GetKind() != hlsl::SigPoint::Kind::MSPOut) {
+        emitError("%0 builtin can only be used as 'primitives' output in MS",
+                  loc)
+            << builtin;
+        success = false;
+      }
+      if (!declType->isArrayType() ||
+          !declType->getArrayElementTypeNoTypeQual()->isSpecificBuiltinType(
+              BuiltinType::Kind::Int)) {
+        emitError("%0 builtin must be of type array of integers", loc)
+            << builtin;
+        success = false;
+      }
     }
   }
 
@@ -2911,6 +3101,8 @@ DeclResultIdMapper::getStorageClassForSigPoint(const hlsl::SigPoint *sigPoint) {
     case hlsl::DXIL::SigPointKind::HSIn:
     case hlsl::DXIL::SigPointKind::GSIn:
     case hlsl::DXIL::SigPointKind::CSIn:
+    case hlsl::DXIL::SigPointKind::MSIn:
+    case hlsl::DXIL::SigPointKind::ASIn:
       sc = spv::StorageClass::Input;
       break;
     default:
@@ -2918,12 +3110,14 @@ DeclResultIdMapper::getStorageClassForSigPoint(const hlsl::SigPoint *sigPoint) {
     }
     break;
   }
-  case hlsl::DXIL::SignatureKind::PatchConstant: {
+  case hlsl::DXIL::SignatureKind::PatchConstOrPrim: {
     // There are some special cases in HLSL (See docs/dxil.rst):
-    // SignatureKind is "PatchConstant" for PCOut and DSIn.
+    // SignatureKind is "PatchConstOrPrim" for PCOut, MSPOut and DSIn.
     switch (sigPointKind) {
     case hlsl::DXIL::SigPointKind::PCOut:
+    case hlsl::DXIL::SigPointKind::MSPOut:
       // Patch Constant Output (Output of Hull which is passed to Domain).
+      // Mesh Shader per-primitive output attributes.
       sc = spv::StorageClass::Output;
       break;
     case hlsl::DXIL::SigPointKind::DSIn:

+ 20 - 5
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -297,6 +297,13 @@ public:
   SpirvVariable *createRayTracingNVStageVar(spv::StorageClass sc,
                                             const VarDecl *decl);
 
+  bool createPayloadStageVars(const hlsl::SigPoint *sigPoint,
+                              spv::StorageClass sc, const NamedDecl *decl,
+                              bool asInput, QualType type,
+                              const llvm::StringRef namePrefix,
+                              SpirvInstruction **value,
+                              uint32_t payloadMemOffset = 0);
+
   /// \brief Creates a function-scope paramter in the current function and
   /// returns its instruction.
   SpirvFunctionParameter *createFnParam(const ParmVarDecl *param);
@@ -479,6 +486,16 @@ public:
   bool decorateResourceBindings();
 
   bool requiresLegalization() const { return needsLegalization; }
+ 
+  /// \brief Returns the given decl's HLSL semantic information.
+  static SemanticInfo getStageVarSemantic(const NamedDecl *decl);
+
+  /// \brief Returns SPIR-V instruction for given stage var decl.
+  SpirvInstruction *getStageVarInstruction(const DeclaratorDecl *decl) {
+    auto *value = stageVarInstructions.lookup(decl);
+    assert(value);
+    return value;
+  }
 
 private:
   /// \brief Wrapper method to create a fatal error message and report it
@@ -557,9 +574,6 @@ private:
       const DeclContext *decl, int arraySize, ContextUsageKind usageKind,
       llvm::StringRef typeName, llvm::StringRef varName);
 
-  /// Returns the given decl's HLSL semantic information.
-  static SemanticInfo getStageVarSemantic(const NamedDecl *decl);
-
   /// Creates all the stage variables mapped from semantics on the given decl.
   /// Returns true on sucess.
   ///
@@ -774,8 +788,9 @@ DeclResultIdMapper::DeclResultIdMapper(ASTContext &context,
       glPerVertex(context, spirvContext, spirvBuilder) {}
 
 bool DeclResultIdMapper::decorateStageIOLocations() {
-  if (spvContext.isRay()) {
-    // No location assignment for any raytracing stage variables
+  if (spvContext.isRay() || spvContext.isAS()) {
+    // No location assignment for any raytracing stage variables or
+    // amplification shader variables
     return true;
   }
   // Try both input and output even if input location assignment failed

+ 3 - 0
tools/clang/lib/SPIRV/FeatureManager.cpp

@@ -119,6 +119,7 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
             Extension::GOOGLE_user_type)
       .Case("SPV_KHR_post_depth_coverage", Extension::KHR_post_depth_coverage)
       .Case("SPV_NV_ray_tracing", Extension::NV_ray_tracing)
+      .Case("SPV_NV_mesh_shader", Extension::NV_mesh_shader)
       .Default(Extension::Unknown);
 }
 
@@ -156,6 +157,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) {
     return "SPV_GOOGLE_user_type";
   case Extension::NV_ray_tracing:
     return "SPV_NV_ray_tracing";
+  case Extension::NV_mesh_shader:
+    return "SPV_NV_mesh_shader";
   default:
     break;
   }

+ 68 - 41
tools/clang/lib/SPIRV/GlPerVertex.cpp

@@ -117,6 +117,12 @@ bool GlPerVertex::recordGlPerVertexDeclFacts(const DeclaratorDecl *decl,
   if (type->isVoidType())
     return true;
 
+  // Indices or payload mesh shader param objects don't contain any
+  // builtin variables or semantic strings. So early return.
+  if (decl->hasAttr<HLSLIndicesAttr>() || decl->hasAttr<HLSLPayloadAttr>()) {
+    return true;
+  }
+
   return doGlPerVertexFacts(decl, type, asInput);
 }
 
@@ -146,18 +152,16 @@ bool GlPerVertex::doGlPerVertexFacts(const DeclaratorDecl *decl,
       return doGlPerVertexFacts(
           decl, hlsl::GetHLSLInputPatchElementType(baseType), asInput);
     }
-    if (hlsl::IsHLSLOutputPatchType(baseType)) {
-      return doGlPerVertexFacts(
-          decl, hlsl::GetHLSLOutputPatchElementType(baseType), asInput);
-    }
-
-    if (hlsl::IsHLSLStreamOutputType(baseType)) {
+    if (hlsl::IsHLSLOutputPatchType(baseType) ||
+        hlsl::IsHLSLStreamOutputType(baseType)) {
       return doGlPerVertexFacts(
           decl, hlsl::GetHLSLOutputPatchElementType(baseType), asInput);
     }
-    if (hasGSPrimitiveTypeQualifier(decl)) {
-      // GS inputs have an additional arrayness that we should remove to check
-      // the underlying type instead.
+    if (hasGSPrimitiveTypeQualifier(decl) ||
+        decl->hasAttr<HLSLVerticesAttr>() ||
+        decl->hasAttr<HLSLPrimitivesAttr>()) {
+      // GS inputs and MS output attribute have an additional arrayness that we
+      // should remove to check the underlying type instead.
       baseType = astContext.getAsConstantArrayType(baseType)->getElementType();
       return doGlPerVertexFacts(decl, baseType, asInput);
     }
@@ -169,7 +173,7 @@ bool GlPerVertex::doGlPerVertexFacts(const DeclaratorDecl *decl,
     return false;
   }
 
-  // Semantic string is attched to this decl directly
+  // Semantic string is attached to this decl directly
 
   // Select the corresponding data member to update
   SemanticIndexToTypeMap *typeMap = nullptr;
@@ -235,7 +239,8 @@ bool GlPerVertex::doGlPerVertexFacts(const DeclaratorDecl *decl,
   }
 
   if (baseType->isConstantArrayType()) {
-    if (spvContext.isHS() || spvContext.isDS() || spvContext.isGS()) {
+    if (spvContext.isHS() || spvContext.isDS() || spvContext.isGS() ||
+        spvContext.isMS()) {
       // Ignore the outermost arrayness and check the inner type to be
       // (vector of) floats
 
@@ -355,11 +360,14 @@ bool GlPerVertex::tryToAccess(hlsl::SigPoint::Kind sigPointKind,
                               uint32_t semanticIndex,
                               llvm::Optional<SpirvInstruction *> invocationId,
                               SpirvInstruction **value, bool noWriteBack,
+                              SpirvInstruction *vecComponent,
                               SourceLocation loc) {
   assert(value);
   // invocationId should only be used for HSPCOut.
-  assert(invocationId.hasValue() ? sigPointKind == hlsl::SigPoint::Kind::HSCPOut
-                                 : true);
+  assert(invocationId.hasValue()
+             ? (sigPointKind == hlsl::SigPoint::Kind::HSCPOut ||
+                sigPointKind == hlsl::SigPoint::Kind::MSOut)
+             : true);
 
   switch (semanticKind) {
   case hlsl::Semantic::Kind::ClipDistance:
@@ -381,10 +389,12 @@ bool GlPerVertex::tryToAccess(hlsl::SigPoint::Kind sigPointKind,
   case hlsl::SigPoint::Kind::VSOut:
   case hlsl::SigPoint::Kind::HSCPOut:
   case hlsl::SigPoint::Kind::DSOut:
+  case hlsl::SigPoint::Kind::MSOut:
     if (noWriteBack)
       return true;
 
-    return writeField(semanticKind, semanticIndex, invocationId, value, loc);
+    return writeField(semanticKind, semanticIndex, invocationId, value,
+                      vecComponent, loc);
   default:
     // Only interfaces that involve gl_PerVertex are needed.
     break;
@@ -526,7 +536,7 @@ bool GlPerVertex::readField(hlsl::Semantic::Kind semanticKind,
 
 void GlPerVertex::writeClipCullArrayFromType(
     llvm::Optional<SpirvInstruction *> invocationId, bool isClip,
-    uint32_t offset, QualType fromType, SpirvInstruction *fromValue,
+    SpirvInstruction *offset, QualType fromType, SpirvInstruction *fromValue,
     SourceLocation loc) const {
   auto *clipCullVar = isClip ? outClipVar : outCullVar;
 
@@ -542,10 +552,8 @@ void GlPerVertex::writeClipCullArrayFromType(
     uint32_t count = {};
 
     if (isScalarType(fromType)) {
-      auto *constant = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
-                                                 llvm::APInt(32, offset));
       auto *ptr =
-          spvBuilder.createAccessChain(f32Type, clipCullVar, {constant}, loc);
+          spvBuilder.createAccessChain(f32Type, clipCullVar, {offset}, loc);
       spvBuilder.createStore(ptr, fromValue, loc);
       return;
     }
@@ -556,9 +564,12 @@ void GlPerVertex::writeClipCullArrayFromType(
       for (uint32_t i = 0; i < count; ++i) {
         // Write elements sequentially into the float array
         auto *constant = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
-                                                   llvm::APInt(32, offset + i));
-        auto *ptr =
-            spvBuilder.createAccessChain(f32Type, clipCullVar, {constant}, loc);
+                                                   llvm::APInt(32, i));
+        auto *ptr = spvBuilder.createAccessChain(
+            f32Type, clipCullVar,
+            {spvBuilder.createBinaryOp(
+                spv::Op::OpIAdd, astContext.UnsignedIntTy, offset, constant)},
+            loc);
         auto *subValue =
             spvBuilder.createCompositeExtract(f32Type, fromValue, {i}, loc);
         spvBuilder.createStore(ptr, subValue, loc);
@@ -571,8 +582,8 @@ void GlPerVertex::writeClipCullArrayFromType(
     return;
   }
 
-  // Writing to an array only happens in HSCPOut.
-  assert(spvContext.isHS());
+  // Writing to an array only happens in HSCPOut or MSOut.
+  assert(spvContext.isHS() || spvContext.isMS());
   // And we are only writing to the array element with InvocationId as index.
   assert(invocationId.hasValue());
 
@@ -588,11 +599,8 @@ void GlPerVertex::writeClipCullArrayFromType(
   uint32_t count = {};
 
   if (isScalarType(fromType)) {
-    auto *ptr = spvBuilder.createAccessChain(
-        f32Type, clipCullVar,
-        {arrayIndex, spvBuilder.getConstantInt(astContext.UnsignedIntTy,
-                                               llvm::APInt(32, offset))},
-        loc);
+    auto *ptr = spvBuilder.createAccessChain(f32Type, clipCullVar,
+                                             {arrayIndex, offset}, loc);
     spvBuilder.createStore(ptr, fromValue, loc);
     return;
   }
@@ -605,8 +613,10 @@ void GlPerVertex::writeClipCullArrayFromType(
           // Block array index
           {arrayIndex,
            // Write elements sequentially into the float array
-           spvBuilder.getConstantInt(astContext.UnsignedIntTy,
-                                     llvm::APInt(32, offset + i))},
+           spvBuilder.createBinaryOp(
+               spv::Op::OpIAdd, astContext.UnsignedIntTy, offset,
+               spvBuilder.getConstantInt(astContext.UnsignedIntTy,
+                                         llvm::APInt(32, i)))},
           loc);
 
       auto *subValue =
@@ -623,7 +633,9 @@ void GlPerVertex::writeClipCullArrayFromType(
 bool GlPerVertex::writeField(hlsl::Semantic::Kind semanticKind,
                              uint32_t semanticIndex,
                              llvm::Optional<SpirvInstruction *> invocationId,
-                             SpirvInstruction **value, SourceLocation loc) {
+                             SpirvInstruction **value,
+                             SpirvInstruction *vecComponent,
+                             SourceLocation loc) {
   // Similar to the writing logic in DeclResultIdMapper::createStageVars():
   //
   // Unlike reading, which may require us to read stand-alone builtins and
@@ -636,9 +648,13 @@ bool GlPerVertex::writeField(hlsl::Semantic::Kind semanticKind,
   //       but we only write to the struct at the InvocationID index
   // * DS: output is a single struct, without extra arrayness
   // * GS: output is controlled by OpEmitVertex, one vertex per time
+  // * MS: output is an array of structs, with extra arrayness
   //
   // The interesting shader stage is HS. We need the InvocationID to write
   // out the value to the correct array element.
+  SpirvInstruction *offset = nullptr;
+  QualType type;
+  bool isClip = false;
   switch (semanticKind) {
   case hlsl::Semantic::Kind::ClipDistance: {
     const auto offsetIter = outClipOffset.find(semanticIndex);
@@ -646,10 +662,11 @@ bool GlPerVertex::writeField(hlsl::Semantic::Kind semanticKind,
     // We should have recorded all these semantics before.
     assert(offsetIter != outClipOffset.end());
     assert(typeIter != outClipType.end());
-    writeClipCullArrayFromType(invocationId, /*isClip=*/true,
-                               offsetIter->second, typeIter->second, *value,
-                               loc);
-    return true;
+    offset = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
+                                       llvm::APInt(32, offsetIter->second));
+    type = typeIter->second;
+    isClip = true;
+    break;
   }
   case hlsl::Semantic::Kind::CullDistance: {
     const auto offsetIter = outCullOffset.find(semanticIndex);
@@ -657,16 +674,26 @@ bool GlPerVertex::writeField(hlsl::Semantic::Kind semanticKind,
     // We should have recorded all these semantics before.
     assert(offsetIter != outCullOffset.end());
     assert(typeIter != outCullType.end());
-    writeClipCullArrayFromType(invocationId, /*isClip=*/false,
-                               offsetIter->second, typeIter->second, *value,
-                               loc);
-    return true;
+    offset = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
+                                       llvm::APInt(32, offsetIter->second));
+    type = typeIter->second;
+    break;
   }
   default:
     // Only Cull or Clip apply.
-    break;
+    return false;
   }
-  return false;
+  if (vecComponent) {
+    QualType elemType;
+    if (!isVectorType(type, &elemType)) {
+      assert(false && "expected vector type");
+    }
+    type = elemType;
+    offset = spvBuilder.createBinaryOp(
+        spv::Op::OpIAdd, astContext.UnsignedIntTy, vecComponent, offset);
+  }
+  writeClipCullArrayFromType(invocationId, isClip, offset, type, *value, loc);
+  return true;
 }
 
 } // end namespace spirv

+ 5 - 4
tools/clang/lib/SPIRV/GlPerVertex.h

@@ -84,7 +84,7 @@ public:
                    uint32_t semanticIndex,
                    llvm::Optional<SpirvInstruction *> invocation,
                    SpirvInstruction **value, bool noWriteBack,
-                   SourceLocation loc);
+                   SpirvInstruction *vecComponent, SourceLocation loc);
 
 private:
   template <unsigned N>
@@ -114,13 +114,14 @@ private:
   /// generated to make sure type correctness.
   void
   writeClipCullArrayFromType(llvm::Optional<SpirvInstruction *> invocationId,
-                             bool isClip, uint32_t offset, QualType fromType,
-                             SpirvInstruction *fromValue,
+                             bool isClip, SpirvInstruction *offset,
+                             QualType fromType, SpirvInstruction *fromValue,
                              SourceLocation loc) const;
   /// Creates SPIR-V instructions to write a field in gl_PerVertex.
   bool writeField(hlsl::Semantic::Kind semanticKind, uint32_t semanticIndex,
                   llvm::Optional<SpirvInstruction *> invocationId,
-                  SpirvInstruction **value, SourceLocation loc);
+                  SpirvInstruction **value, SpirvInstruction *vecComponent,
+                  SourceLocation loc);
 
   /// Internal implementation for recordClipCullDistanceDecl().
   bool doGlPerVertexFacts(const DeclaratorDecl *decl, QualType type,

+ 17 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -965,6 +965,23 @@ void SpirvBuilder::decorateNoContraction(SpirvInstruction *target,
   module->addDecoration(decor);
 }
 
+void SpirvBuilder::decoratePerPrimitiveNV(SpirvInstruction *target,
+                                          SourceLocation srcLoc) {
+  auto *decor = new (context)
+      SpirvDecoration(srcLoc, target, spv::Decoration::PerPrimitiveNV);
+  module->addDecoration(decor);
+}
+
+void SpirvBuilder::decoratePerTaskNV(SpirvInstruction *target, uint32_t offset,
+                                     SourceLocation srcLoc) {
+  auto *decor =
+      new (context) SpirvDecoration(srcLoc, target, spv::Decoration::PerTaskNV);
+  module->addDecoration(decor);
+  decor = new (context)
+      SpirvDecoration(srcLoc, target, spv::Decoration::Offset, {offset});
+  module->addDecoration(decor);
+}
+
 SpirvConstant *SpirvBuilder::getConstantInt(QualType type, llvm::APInt value,
                                             bool specConst) {
   // We do not reuse existing constant integers. Just create a new one.

+ 528 - 43
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -511,18 +511,22 @@ SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
     spirvOptions.cBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
     spirvOptions.tBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
     spirvOptions.sBufferLayoutRule = SpirvLayoutRule::FxcSBuffer;
+    spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::FxcSBuffer;
   } else if (spirvOptions.useGlLayout) {
     spirvOptions.cBufferLayoutRule = SpirvLayoutRule::GLSLStd140;
     spirvOptions.tBufferLayoutRule = SpirvLayoutRule::GLSLStd430;
     spirvOptions.sBufferLayoutRule = SpirvLayoutRule::GLSLStd430;
+    spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::GLSLStd430;
   } else if (spirvOptions.useScalarLayout) {
     spirvOptions.cBufferLayoutRule = SpirvLayoutRule::Scalar;
     spirvOptions.tBufferLayoutRule = SpirvLayoutRule::Scalar;
     spirvOptions.sBufferLayoutRule = SpirvLayoutRule::Scalar;
+    spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::Scalar;
   } else {
     spirvOptions.cBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd140;
     spirvOptions.tBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
     spirvOptions.sBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
+    spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
   }
 
   // Set shader module version, source file name, and source file content (if
@@ -4989,6 +4993,11 @@ SpirvEmitter::processAssignment(const Expr *lhs, SpirvInstruction *rhs,
   if (SpirvInstruction *result = tryToAssignToRWBufferRWTexture(lhs, rhs))
     return result;
 
+  // Assigning to a out attribute or indices object in mesh shader should be
+  // handled differently.
+  if (SpirvInstruction *result = tryToAssignToMSOutAttrsOrIndices(lhs, rhs))
+    return result;
+
   // Normal assignment procedure
 
   if (!lhsPtr)
@@ -5833,6 +5842,11 @@ SpirvEmitter::tryToAssignToVectorElements(const Expr *lhs,
       (void)result;
       return rhs; // TODO: incorrect for compound assignments
     } else {
+      // Assigning to one component of mesh out attribute/indices vector object.
+      SpirvInstruction *vecComponent = spvBuilder.getConstantInt(
+          astContext.UnsignedIntTy, llvm::APInt(32, accessor.Swz0));
+      if (tryToAssignToMSOutAttrsOrIndices(base, rhs, vecComponent))
+        return rhs;
       // Assigning to one normal vector component. Nothing special, just fall
       // back to the normal CodeGen path.
       return nullptr;
@@ -5854,6 +5868,26 @@ SpirvEmitter::tryToAssignToVectorElements(const Expr *lhs,
     return processAssignment(base, rhs, false);
   }
 
+  if (tryToAssignToMSOutAttrsOrIndices(base, rhs, /*vecComponent=*/nullptr,
+                                       /*noWriteBack=*/true)) {
+    // Assigning to 'n' components of mesh out attribute/indices vector object.
+    const QualType elemType =
+        hlsl::GetHLSLVecElementType(rhs->getAstResultType());
+    uint32_t i = 0;
+    for (; i < accessor.Count; ++i) {
+      auto *rhsElem = spvBuilder.createCompositeExtract(elemType, rhs, {i},
+                                                        lhs->getLocStart());
+      uint32_t position;
+      accessor.GetPosition(i, &position);
+      SpirvInstruction *vecComponent = spvBuilder.getConstantInt(
+          astContext.UnsignedIntTy, llvm::APInt(32, position));
+      if (!tryToAssignToMSOutAttrsOrIndices(base, rhsElem, vecComponent))
+        break;
+    }
+    assert(i == accessor.Count);
+    return rhs;
+  }
+
   llvm::SmallVector<uint32_t, 4> selectors;
   selectors.resize(baseSize);
   // Assume we are selecting all original elements first.
@@ -5969,6 +6003,191 @@ SpirvEmitter::tryToAssignToMatrixElements(const Expr *lhs,
   return rhs;
 }
 
+SpirvInstruction *SpirvEmitter::tryToAssignToMSOutAttrsOrIndices(
+    const Expr *lhs, SpirvInstruction *rhs, SpirvInstruction *vecComponent,
+    bool noWriteBack) {
+  // Early exit for non-mesh shaders.
+  if (!spvContext.isMS())
+    return nullptr;
+
+  llvm::SmallVector<SpirvInstruction *, 4> indices;
+  bool isMSOutAttribute = false;
+  bool isMSOutAttributeBlock = false;
+  bool isMSOutIndices = false;
+
+  const Expr *base = collectArrayStructIndices(lhs, /*rawIndex*/ false,
+                                               /*rawIndices*/ nullptr, &indices,
+                                               &isMSOutAttribute);
+  // Expecting at least one array index - early exit.
+  if (!base || indices.empty())
+    return nullptr;
+
+  const DeclaratorDecl *varDecl = nullptr;
+  if (isMSOutAttribute) {
+    const MemberExpr *memberExpr = dyn_cast<MemberExpr>(base);
+    assert(memberExpr);
+    varDecl = cast<DeclaratorDecl>(memberExpr->getMemberDecl());
+  } else {
+    if (const auto *arg = dyn_cast<DeclRefExpr>(base)) {
+      if (varDecl = dyn_cast<DeclaratorDecl>(arg->getDecl())) {
+        if (varDecl->hasAttr<HLSLIndicesAttr>()) {
+          isMSOutIndices = true;
+        } else if (varDecl->hasAttr<HLSLVerticesAttr>() ||
+                   varDecl->hasAttr<HLSLPrimitivesAttr>()) {
+          isMSOutAttributeBlock = true;
+        }
+      }
+    }
+  }
+
+  // Return if no out attribute or indices object found.
+  if (!(isMSOutAttribute || isMSOutAttributeBlock || isMSOutIndices)) {
+    return nullptr;
+  }
+
+  // For noWriteBack, return without generating write instructions.
+  if (noWriteBack) {
+    return rhs;
+  }
+
+  // Add vecComponent to indices.
+  if (vecComponent) {
+    indices.push_back(vecComponent);
+  }
+
+  if (isMSOutAttribute) {
+    assignToMSOutAttribute(varDecl, rhs, indices);
+  } else if (isMSOutIndices) {
+    assignToMSOutIndices(varDecl, rhs, indices);
+  } else {
+    assert(isMSOutAttributeBlock);
+    QualType type = varDecl->getType();
+    assert(isa<ConstantArrayType>(type));
+    type = astContext.getAsConstantArrayType(type)->getElementType();
+    assert(type->isStructureType());
+
+    // Extract subvalue and assign to its corresponding member attribute.
+    const auto *structDecl = type->getAs<RecordType>()->getDecl();
+    for (const auto *field : structDecl->fields()) {
+      const auto fieldType = field->getType();
+      SpirvInstruction *subValue = spvBuilder.createCompositeExtract(
+          fieldType, rhs, {getNumBaseClasses(type) + field->getFieldIndex()},
+          lhs->getLocStart());
+      assignToMSOutAttribute(field, subValue, indices);
+    }
+  }
+
+  // TODO: OK, this return value is incorrect for compound assignments, for
+  // which cases we should return lvalues. Should at least emit errors if
+  // this return value is used (can be checked via ASTContext.getParents).
+  return rhs;
+}
+
+void SpirvEmitter::assignToMSOutAttribute(
+    const DeclaratorDecl *decl, SpirvInstruction *value,
+    const llvm::SmallVector<SpirvInstruction *, 4> &indices) {
+  assert(spvContext.isMS() && !indices.empty());
+
+  // Extract attribute index and vecComponent (if any).
+  SpirvInstruction *attrIndex = indices.front();
+  SpirvInstruction *vecComponent = nullptr;
+  if (indices.size() > 1) {
+    vecComponent = indices.back();
+  }
+
+  auto semanticInfo = declIdMapper.getStageVarSemantic(decl);
+  assert(semanticInfo.isValid());
+  const auto loc = decl->getLocation();
+  // Special handle writes to clip/cull distance attributes.
+  if (!declIdMapper.glPerVertex.tryToAccess(
+          hlsl::DXIL::SigPointKind::MSOut, semanticInfo.semantic->GetKind(),
+          semanticInfo.index, attrIndex, &value, /*noWriteBack=*/false,
+          vecComponent, loc)) {
+    // All other attribute writes are handled below.
+    auto *varInstr = declIdMapper.getStageVarInstruction(decl);
+    QualType valueType = value->getAstResultType();
+    varInstr = spvBuilder.createAccessChain(valueType, varInstr, indices, loc);
+    spvBuilder.createStore(varInstr, value, loc);
+  }
+}
+
+void SpirvEmitter::assignToMSOutIndices(
+    const DeclaratorDecl *decl, SpirvInstruction *value,
+    const llvm::SmallVector<SpirvInstruction *, 4> &indices) {
+  assert(spvContext.isMS() && !indices.empty());
+
+  // Extract vertex index and vecComponent (if any).
+  SpirvInstruction *vertIndex = indices.front();
+  SpirvInstruction *vecComponent = nullptr;
+  if (indices.size() > 1) {
+    vecComponent = indices.back();
+  }
+  auto *var = declIdMapper.getStageVarInstruction(decl);
+  const auto *varTypeDecl = astContext.getAsConstantArrayType(decl->getType());
+  QualType varType = varTypeDecl->getElementType();
+  uint32_t numVertices = 1;
+  if (!isVectorType(varType, nullptr, &numVertices)) {
+    assert(isScalarType(varType));
+  }
+  QualType valueType = value->getAstResultType();
+  uint32_t numValues = 1;
+  if (!isVectorType(valueType, nullptr, &numValues)) {
+    assert(isScalarType(valueType));
+  }
+
+  const auto loc = decl->getLocation();
+  if (numVertices == 1) {
+    // for "point" output topology.
+    assert(numValues == 1);
+    // create accesschain for PrimitiveIndicesNV[vertIndex].
+    auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
+                                             vertIndex, loc);
+    // finally create store for PrimitiveIndicesNV[vertIndex] = value.
+    spvBuilder.createStore(ptr, value, loc);
+  } else {
+    // for "line" or "triangle" output topology.
+    assert(numVertices == 2 || numVertices == 3);
+    // set baseOffset = vertIndex * numVertices.
+    auto *baseOffset = spvBuilder.createBinaryOp(
+        spv::Op::OpIMul, astContext.UnsignedIntTy, vertIndex,
+        spvBuilder.getConstantInt(astContext.UnsignedIntTy,
+                                  llvm::APInt(32, numVertices)));
+    if (vecComponent) {
+      // write an individual vector component of uint2 or uint3.
+      assert(numValues == 1);
+      // set baseOffset = baseOffset + vecComponent.
+      baseOffset = spvBuilder.createBinaryOp(
+          spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset, vecComponent);
+      // create accesschain for PrimitiveIndicesNV[baseOffset].
+      auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
+                                               baseOffset, loc);
+      // finally create store for PrimitiveIndicesNV[baseOffset] = value.
+      spvBuilder.createStore(ptr, value, loc);
+    } else {
+      // write all vector components of uint2 or uint3.
+      assert(numValues == numVertices);
+      auto *curOffset = baseOffset;
+      for (uint32_t i = 0; i < numValues; ++i) {
+        if (i != 0) {
+          // set curOffset = baseOffset + i.
+          curOffset = spvBuilder.createBinaryOp(
+              spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset,
+              spvBuilder.getConstantInt(astContext.UnsignedIntTy,
+                                        llvm::APInt(32, i)));
+        }
+        // create accesschain for PrimitiveIndicesNV[curOffset].
+        auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
+                                                 curOffset, loc);
+        // finally create store for PrimitiveIndicesNV[curOffset] = value[i].
+        spvBuilder.createStore(ptr,
+                               spvBuilder.createCompositeExtract(
+                                   astContext.UnsignedIntTy, value, {i}, loc),
+                               loc);
+      }
+    }
+  }
+}
+
 SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
     const Expr *matrix, SpirvInstruction *matrixVal,
     llvm::function_ref<SpirvInstruction *(uint32_t, QualType,
@@ -6125,7 +6344,8 @@ SpirvEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
 const Expr *SpirvEmitter::collectArrayStructIndices(
     const Expr *expr, bool rawIndex,
     llvm::SmallVectorImpl<uint32_t> *rawIndices,
-    llvm::SmallVectorImpl<SpirvInstruction *> *indices) {
+    llvm::SmallVectorImpl<SpirvInstruction *> *indices,
+    bool *isMSOutAttribute) {
   assert((rawIndex && rawIndices) || (!rawIndex && indices));
 
   if (const auto *indexing = dyn_cast<MemberExpr>(expr)) {
@@ -6140,7 +6360,20 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
 
     const Expr *base = collectArrayStructIndices(
         indexing->getBase()->IgnoreParenNoopCasts(astContext), rawIndex,
-        rawIndices, indices);
+        rawIndices, indices, isMSOutAttribute);
+
+    if (isMSOutAttribute && base) {
+      if (const auto *arg = dyn_cast<DeclRefExpr>(base)) {
+        if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
+          if (varDecl->hasAttr<HLSLVerticesAttr>() ||
+              varDecl->hasAttr<HLSLPrimitivesAttr>()) {
+            assert(spvContext.isMS());
+            *isMSOutAttribute = true;
+            return expr;
+          }
+        }
+      }
+    }
 
     // Append the index of the current level
     const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
@@ -6167,8 +6400,8 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
     // The base of an ArraySubscriptExpr has a wrapping LValueToRValue implicit
     // cast. We need to ingore it to avoid creating OpLoad.
     const Expr *thisBase = indexing->getBase()->IgnoreParenLValueCasts();
-    const Expr *base =
-        collectArrayStructIndices(thisBase, rawIndex, rawIndices, indices);
+    const Expr *base = collectArrayStructIndices(thisBase, rawIndex, rawIndices,
+                                                 indices, isMSOutAttribute);
     indices->push_back(doExpr(indexing->getIdx()));
     return base;
   }
@@ -6188,8 +6421,8 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
           indexing->getArg(0)->IgnoreParenNoopCasts(astContext);
 
       const auto thisBaseType = thisBase->getType();
-      const Expr *base =
-          collectArrayStructIndices(thisBase, rawIndex, rawIndices, indices);
+      const Expr *base = collectArrayStructIndices(
+          thisBase, rawIndex, rawIndices, indices, isMSOutAttribute);
 
       if (thisBaseType != base->getType() &&
           isAKindOfStructuredOrByteBuffer(thisBaseType)) {
@@ -6841,6 +7074,14 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
   case hlsl::IntrinsicOp::IOP_CallShader: {
     processCallShader(callExpr);
     break;
+  }
+  case hlsl::IntrinsicOp::IOP_DispatchMesh: {
+    processDispatchMesh(callExpr);
+    break;
+  }
+  case hlsl::IntrinsicOp::IOP_SetMeshOutputCounts: {
+    processMeshOutputCounts(callExpr);
+    break;
   }
     INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
     INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
@@ -9172,6 +9413,7 @@ SpirvInstruction *SpirvEmitter::processReportHit(const CallExpr *callExpr) {
                                           astContext.BoolTy, reportHitArgs,
                                           callExpr->getExprLoc());
 }
+
 void SpirvEmitter::processCallShader(const CallExpr *callExpr) {
   SpirvInstruction *callDataLocInst = nullptr;
   SpirvInstruction *callDataStageVar = nullptr;
@@ -9238,11 +9480,12 @@ void SpirvEmitter::processCallShader(const CallExpr *callExpr) {
   spvBuilder.createStore(callDataArgInst, tempLoad, callExpr->getExprLoc());
   return;
 }
+
 void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
-  SpirvInstruction *payloadLocInst = nullptr;
-  SpirvInstruction *payloadStageVar = nullptr;
-  const VarDecl *payloadArg = nullptr;
-  QualType payloadType;
+  SpirvInstruction *rayPayloadLocInst = nullptr;
+  SpirvInstruction *rayPayloadStageVar = nullptr;
+  const VarDecl *rayPayloadArg = nullptr;
+  QualType rayPayloadType;
 
   const auto args = callExpr->getArgs();
 
@@ -9252,7 +9495,7 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
   }
 
   // HLSL Func
-  // template<typename Payload>
+  // template<typename RayPayload>
   // void TraceRay(RaytracingAccelerationStructure rs,
   //              uint rayflags,
   //              uint InstanceInclusionMask
@@ -9260,36 +9503,36 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
   //              uint MultiplierForGeometryContributionToHitGroupIndex,
   //              uint MissShaderIndex,
   //              RayDesc ray,
-  //              inout Payload p)
+  //              inout RayPayload p)
   // where RayDesc = {float3 origin, float tMin, float3 direction, float tMax}
 
   if (const auto *implCastExpr = dyn_cast<CastExpr>(args[7])) {
     if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
       if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
-        payloadType = varDecl->getType();
-        payloadArg = varDecl;
-        const auto payloadPair = payloadMap.find(payloadType);
-        // Check if same type of payload stage variable was already
+        rayPayloadType = varDecl->getType();
+        rayPayloadArg = varDecl;
+        const auto rayPayloadPair = rayPayloadMap.find(rayPayloadType);
+        // Check if same type of rayPayload stage variable was already
         // created, if so re-use
-        if (payloadPair == payloadMap.end()) {
-          int numPayloadVars = payloadMap.size();
-          payloadStageVar = declIdMapper.createRayTracingNVStageVar(
+        if (rayPayloadPair == rayPayloadMap.end()) {
+          int numPayloadVars = rayPayloadMap.size();
+          rayPayloadStageVar = declIdMapper.createRayTracingNVStageVar(
               spv::StorageClass::RayPayloadNV, varDecl);
           // Decorate unique location id for each created stage var
-          spvBuilder.decorateLocation(payloadStageVar, numPayloadVars);
-          payloadLocInst = spvBuilder.getConstantInt(
+          spvBuilder.decorateLocation(rayPayloadStageVar, numPayloadVars);
+          rayPayloadLocInst = spvBuilder.getConstantInt(
               astContext.UnsignedIntTy, llvm::APInt(32, numPayloadVars));
-          payloadMap[payloadType] =
-              std::make_pair(payloadStageVar, payloadLocInst);
+          rayPayloadMap[rayPayloadType] =
+              std::make_pair(rayPayloadStageVar, rayPayloadLocInst);
         } else {
-          payloadStageVar = payloadPair->second.first;
-          payloadLocInst = payloadPair->second.second;
+          rayPayloadStageVar = rayPayloadPair->second.first;
+          rayPayloadLocInst = rayPayloadPair->second.second;
         }
       }
     }
   }
 
-  assert(payloadStageVar && payloadArg);
+  assert(rayPayloadStageVar && rayPayloadArg);
 
   const auto floatType = astContext.FloatTy;
   const auto vecType = astContext.getExtVectorType(astContext.FloatTy, 3);
@@ -9307,11 +9550,12 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
       spvBuilder.createCompositeExtract(floatType, rayDescArg, {3}, loc);
 
   // Copy argument to stage variable
-  const auto payloadArgInst =
-      declIdMapper.getDeclEvalInfo(payloadArg, payloadArg->getLocStart());
-  auto tempLoad = spvBuilder.createLoad(payloadArg->getType(), payloadArgInst,
-                                        payloadArg->getLocStart());
-  spvBuilder.createStore(payloadStageVar, tempLoad, callExpr->getExprLoc());
+  const auto rayPayloadArgInst =
+      declIdMapper.getDeclEvalInfo(rayPayloadArg, rayPayloadArg->getLocStart());
+  auto tempLoad =
+      spvBuilder.createLoad(rayPayloadArg->getType(), rayPayloadArgInst,
+                            rayPayloadArg->getLocStart());
+  spvBuilder.createStore(rayPayloadStageVar, tempLoad, callExpr->getExprLoc());
 
   // SPIR-V Instruction
   // void OpTraceNV ( <id> AccelerationStructureNV acStruct,
@@ -9324,7 +9568,7 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
   //                 <id> float Ray Tmin,
   //                 <id> vec3 Ray Direction,
   //                 <id> float Ray Tmax,
-  //                 <id> uint Payload number)
+  //                 <id> uint RayPayload number)
 
   llvm::SmallVector<SpirvInstruction *, 8> traceArgs;
   for (int ii = 0; ii < 6; ii++) {
@@ -9335,18 +9579,78 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
   traceArgs.push_back(tMin);
   traceArgs.push_back(direction);
   traceArgs.push_back(tMax);
-  traceArgs.push_back(payloadLocInst);
+  traceArgs.push_back(rayPayloadLocInst);
 
   spvBuilder.createRayTracingOpsNV(spv::Op::OpTraceNV, QualType(), traceArgs,
                                    callExpr->getExprLoc());
 
   // Copy arguments back to stage variable
-  tempLoad = spvBuilder.createLoad(payloadArg->getType(), payloadStageVar,
-                                   payloadArg->getLocStart());
-  spvBuilder.createStore(payloadArgInst, tempLoad, callExpr->getExprLoc());
+  tempLoad = spvBuilder.createLoad(rayPayloadArg->getType(), rayPayloadStageVar,
+                                   rayPayloadArg->getLocStart());
+  spvBuilder.createStore(rayPayloadArgInst, tempLoad, callExpr->getExprLoc());
   return;
 }
 
+void SpirvEmitter::processDispatchMesh(const CallExpr *callExpr) {
+  // HLSL Func - void DispatchMesh(uint ThreadGroupCountX,
+  //                               uint ThreadGroupCountY,
+  //                               uint ThreadGroupCountZ,
+  //                               groupshared <structType> MeshPayload);
+  assert(callExpr->getNumArgs() == 4);
+  const auto args = callExpr->getArgs();
+  const auto loc = callExpr->getExprLoc();
+
+  // 1) create a barrier GroupMemoryBarrierWithGroupSync().
+  processIntrinsicMemoryBarrier(callExpr,
+                                /*isDevice*/ false,
+                                /*groupSync*/ true,
+                                /*isAllBarrier*/ false);
+
+  // 2) set TaskCountNV = threadX * threadY * threadZ.
+  auto *threadX = doExpr(args[0]);
+  auto *threadY = doExpr(args[1]);
+  auto *threadZ = doExpr(args[2]);
+  auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::TaskCountNV,
+                                         astContext.UnsignedIntTy, loc);
+  auto *taskCount = spvBuilder.createBinaryOp(
+      spv::Op::OpIMul, astContext.UnsignedIntTy, threadX,
+      spvBuilder.createBinaryOp(spv::Op::OpIMul, astContext.UnsignedIntTy,
+                                threadY, threadZ));
+  spvBuilder.createStore(var, taskCount, loc);
+
+  // 3) create PerTaskNV out attribute block and store MeshPayload info.
+  const auto *sigPoint =
+      hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::MSOut);
+  spv::StorageClass sc = spv::StorageClass::Output;
+  auto *payloadArg = doExpr(args[3]);
+  bool isValid = false;
+  if (const auto *implCastExpr = dyn_cast<CastExpr>(args[3])) {
+    if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
+      if (const auto *paramDecl = dyn_cast<VarDecl>(arg->getDecl())) {
+        if (paramDecl->hasAttr<HLSLGroupSharedAttr>()) {
+          isValid = declIdMapper.createPayloadStageVars(
+              sigPoint, sc, paramDecl, /*asInput=*/false, paramDecl->getType(),
+              "out.var", &payloadArg);
+        }
+      }
+    }
+  }
+  if (!isValid) {
+    emitError("expected groupshared object as argument to DispatchMesh()",
+              args[3]->getExprLoc());
+  }
+}
+
+void SpirvEmitter::processMeshOutputCounts(const CallExpr *callExpr) {
+  // HLSL Func - void SetMeshOutputCounts(uint numVertices, uint numPrimitives);
+  assert(callExpr->getNumArgs() == 2);
+  const auto args = callExpr->getArgs();
+  const auto loc = callExpr->getExprLoc();
+  auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::PrimitiveCountNV,
+                                         astContext.UnsignedIntTy, loc);
+  spvBuilder.createStore(var, doExpr(args[1]), loc);
+}
+
 SpirvConstant *SpirvEmitter::getValueZero(QualType type) {
   {
     QualType scalarType = {};
@@ -9637,10 +9941,24 @@ hlsl::ShaderModel::Kind SpirvEmitter::getShaderModelKind(StringRef stageName) {
     smk = hlsl::ShaderModel::Kind::Intersection;
     break;
   case 'a':
-    smk = hlsl::ShaderModel::Kind::AnyHit;
+    switch (stageName[1]) {
+    case 'm':
+      smk = hlsl::ShaderModel::Kind::Amplification;
+      break;
+    case 'n':
+      smk = hlsl::ShaderModel::Kind::AnyHit;
+      break;
+    }
     break;
   case 'm':
-    smk = hlsl::ShaderModel::Kind::Miss;
+    switch (stageName[1]) {
+    case 'e':
+      smk = hlsl::ShaderModel::Kind::Mesh;
+      break;
+    case 'i':
+      smk = hlsl::ShaderModel::Kind::Miss;
+      break;
+    }
     break;
   default:
     smk = hlsl::ShaderModel::Kind::Invalid;
@@ -9679,6 +9997,10 @@ SpirvEmitter::getSpirvShaderStage(hlsl::ShaderModel::Kind smk) {
     return spv::ExecutionModel::MissNV;
   case hlsl::ShaderModel::Kind::Callable:
     return spv::ExecutionModel::CallableNV;
+  case hlsl::ShaderModel::Kind::Mesh:
+    return spv::ExecutionModel::MeshNV;
+  case hlsl::ShaderModel::Kind::Amplification:
+    return spv::ExecutionModel::TaskNV;
   default:
     llvm_unreachable("invalid shader model kind");
     break;
@@ -9942,8 +10264,8 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
     paramTypes.push_back(paramType);
 
     // Order of arguments is fixed
-    // Any-Hit/Closest-Hit : Arg 0 = payload(inout), Arg1 = attribute(in)
-    // Miss : Arg 0 = payload(inout)
+    // Any-Hit/Closest-Hit : Arg 0 = rayPayload(inout), Arg1 = attribute(in)
+    // Miss : Arg 0 = rayPayload(inout)
     // Callable : Arg 0 = callable data(inout)
     // Raygeneration/Intersection : No Args allowed
     if (sKind == hlsl::ShaderModel::Kind::RayGeneration) {
@@ -9954,7 +10276,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
                sKind == hlsl::ShaderModel::Kind::AnyHit) {
       // Generate rayPayloadInNV and hitAttributeNV stage variables
       if (i == 0) {
-        // First argument is always payload
+        // First argument is always rayPayload
         curStageVar = declIdMapper.createRayTracingNVStageVar(
             spv::StorageClass::IncomingRayPayloadNV, param);
       } else {
@@ -9964,7 +10286,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
       }
     } else if (sKind == hlsl::ShaderModel::Kind::Miss) {
       // Generate rayPayloadInNV stage variable
-      // First and only argument is payload
+      // First and only argument is rayPayload
       curStageVar = declIdMapper.createRayTracingNVStageVar(
           spv::StorageClass::IncomingRayPayloadNV, param);
     } else if (sKind == hlsl::ShaderModel::Kind::Callable) {
@@ -10004,6 +10326,166 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
   return true;
 }
 
+bool SpirvEmitter::processMeshOrAmplificationShaderAttributes(
+    const FunctionDecl *decl, uint32_t *outVerticesArraySize) {
+  if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
+    uint32_t x, y, z;
+    x = static_cast<uint32_t>(numThreadsAttr->getX());
+    y = static_cast<uint32_t>(numThreadsAttr->getY());
+    z = static_cast<uint32_t>(numThreadsAttr->getZ());
+    spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
+                                {x, y, z}, decl->getLocation());
+  }
+
+  // Early return for amplification shaders as they only take the 'numthreads'
+  // attribute.
+  if (spvContext.isAS())
+    return true;
+
+  spv::ExecutionMode outputPrimitive = spv::ExecutionMode::Max;
+  if (auto *outputTopology = decl->getAttr<HLSLOutputTopologyAttr>()) {
+    const auto topology = outputTopology->getTopology().lower();
+    outputPrimitive =
+        llvm::StringSwitch<spv::ExecutionMode>(topology)
+            .Case("point", spv::ExecutionMode::OutputPoints)
+            .Case("line", spv::ExecutionMode::OutputLinesNV)
+            .Case("triangle", spv::ExecutionMode::OutputTrianglesNV);
+    if (outputPrimitive != spv::ExecutionMode::Max) {
+      spvBuilder.addExecutionMode(entryFunction, outputPrimitive, {},
+                                  decl->getLocation());
+    } else {
+      emitError("unknown output topology in mesh shader",
+                outputTopology->getLocation());
+      return false;
+    }
+  }
+
+  uint32_t numVertices = 0;
+  uint32_t numIndices = 0;
+  uint32_t numPrimitives = 0;
+  bool payloadDeclSeen = false;
+
+  for (uint32_t i = 0; i < decl->getNumParams(); i++) {
+    const auto param = decl->getParamDecl(i);
+    const auto paramType = param->getType();
+    const auto paramLoc = param->getLocation();
+    if (param->hasAttr<HLSLVerticesAttr>() ||
+        param->hasAttr<HLSLIndicesAttr>() ||
+        param->hasAttr<HLSLPrimitivesAttr>()) {
+      uint32_t arraySize = 0;
+      if (const auto *arrayType =
+              astContext.getAsConstantArrayType(paramType)) {
+        const auto eleType =
+            arrayType->getElementType()->getCanonicalTypeUnqualified();
+        if (param->hasAttr<HLSLIndicesAttr>()) {
+          switch (outputPrimitive) {
+          case spv::ExecutionMode::OutputPoints:
+            if (eleType != astContext.UnsignedIntTy) {
+              emitError("expected 1D array of uint type", paramLoc);
+              return false;
+            }
+            break;
+          case spv::ExecutionMode::OutputLinesNV: {
+            QualType baseType;
+            uint32_t length;
+            if (!isVectorType(eleType, &baseType, &length) ||
+                baseType != astContext.UnsignedIntTy || length != 2) {
+              emitError("expected 1D array of uint2 type", paramLoc);
+              return false;
+            }
+            break;
+          }
+          case spv::ExecutionMode::OutputTrianglesNV: {
+            QualType baseType;
+            uint32_t length;
+            if (!isVectorType(eleType, &baseType, &length) ||
+                baseType != astContext.UnsignedIntTy || length != 3) {
+              emitError("expected 1D array of uint3 type", paramLoc);
+              return false;
+            }
+            break;
+          }
+          default:
+            assert(false && "unexpected spirv execution mode");
+          }
+        } else if (!eleType->isStructureType()) {
+          // vertices/primitives objects
+          emitError("expected 1D array of struct type", paramLoc);
+          return false;
+        }
+        arraySize = static_cast<uint32_t>(arrayType->getSize().getZExtValue());
+      } else {
+        emitError("expected 1D array of indices/vertices/primitives object",
+                  paramLoc);
+        return false;
+      }
+      if (param->hasAttr<HLSLVerticesAttr>()) {
+        if (numVertices != 0) {
+          emitError("only one object with 'vertices' modifier is allowed",
+                    paramLoc);
+          return false;
+        }
+        numVertices = arraySize;
+      } else if (param->hasAttr<HLSLIndicesAttr>()) {
+        if (numIndices != 0) {
+          emitError("only one object with 'indices' modifier is allowed",
+                    paramLoc);
+          return false;
+        }
+        numIndices = arraySize;
+      } else if (param->hasAttr<HLSLPrimitivesAttr>()) {
+        if (numPrimitives != 0) {
+          emitError("only one object with 'primitives' modifier is allowed",
+                    paramLoc);
+          return false;
+        }
+        numPrimitives = arraySize;
+      }
+    } else if (param->hasAttr<HLSLPayloadAttr>()) {
+      if (payloadDeclSeen) {
+        emitError("only one object with 'payload' modifier is allowed",
+                  paramLoc);
+        return false;
+      }
+      payloadDeclSeen = true;
+      if (!paramType->isStructureType()) {
+        emitError("expected payload of struct type", paramLoc);
+        return false;
+      }
+    }
+  }
+
+  // Vertex attribute array is a mandatory param to mesh entry function.
+  if (numVertices != 0) {
+    *outVerticesArraySize = numVertices;
+    spvBuilder.addExecutionMode(
+        entryFunction, spv::ExecutionMode::OutputVertices,
+        {static_cast<uint32_t>(numVertices)}, decl->getLocation());
+  } else {
+    emitError("expected vertices object declaration", decl->getLocation());
+    return false;
+  }
+
+  // Vertex indices array is a mandatory param to mesh entry function.
+  if (numIndices != 0) {
+    spvBuilder.addExecutionMode(
+        entryFunction, spv::ExecutionMode::OutputPrimitivesNV,
+        {static_cast<uint32_t>(numIndices)}, decl->getLocation());
+    // Primitive attribute array is an optional param to mesh entry function,
+    // but the array size should match the indices array.
+    if (numPrimitives != 0 && numPrimitives != numIndices) {
+      emitError("array size of primitives object should match 'indices' object",
+                decl->getLocation());
+      return false;
+    }
+  } else {
+    emitError("expected indices object declaration", decl->getLocation());
+    return false;
+  }
+
+  return true;
+}
+
 bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
                                             SpirvFunction *entryFuncInstr) {
   // HS specific attributes
@@ -10073,6 +10555,9 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
     if (!processGeometryShaderAttributes(decl, &inputArraySize))
       return false;
     // The per-vertex output of GS is not an array.
+  } else if (spvContext.isMS() || spvContext.isAS()) {
+    if (!processMeshOrAmplificationShaderAttributes(decl, &outputArraySize))
+      return false;
   }
 
   // Go through all parameters and record the declaration of SV_ClipDistance
@@ -10101,7 +10586,7 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   // offset of SV_ClipDistance/SV_CullDistance variables within the array.
   declIdMapper.glPerVertex.calculateClipCullDistanceArraySize();
 
-  if (!spvContext.isCS()) {
+  if (!spvContext.isCS() && !spvContext.isAS()) {
     // Generate stand-alone builtins of Position, ClipDistance, and
     // CullDistance, which belongs to gl_PerVertex.
     declIdMapper.glPerVertex.generateVars(inputArraySize, outputArraySize);

+ 33 - 2
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -260,6 +260,24 @@ private:
   SpirvInstruction *tryToAssignToRWBufferRWTexture(const Expr *lhs,
                                                    SpirvInstruction *rhs);
 
+  /// Tries to emit instructions for assigning to the given mesh out attribute
+  /// or indices object. Returns 0 if the trial fails and no instructions are
+  /// generated.
+  SpirvInstruction *
+  tryToAssignToMSOutAttrsOrIndices(const Expr *lhs, SpirvInstruction *rhs,
+                                   SpirvInstruction *vecComponent = nullptr,
+                                   bool noWriteBack = false);
+
+  /// Emit instructions for assigning to the given mesh out attribute.
+  void assignToMSOutAttribute(
+      const DeclaratorDecl *decl, SpirvInstruction *value,
+      const llvm::SmallVector<SpirvInstruction *, 4> &indices);
+
+  /// Emit instructions for assigning to the given mesh out indices object.
+  void
+  assignToMSOutIndices(const DeclaratorDecl *decl, SpirvInstruction *value,
+                       const llvm::SmallVector<SpirvInstruction *, 4> &indices);
+
   /// Processes each vector within the given matrix by calling actOnEachVector.
   /// matrixVal should be the loaded value of the matrix. actOnEachVector takes
   /// three parameters for the current vector: the index, the <type-id>, and
@@ -294,7 +312,8 @@ private:
   const Expr *
   collectArrayStructIndices(const Expr *expr, bool rawIndex,
                             llvm::SmallVectorImpl<uint32_t> *rawIndices,
-                            llvm::SmallVectorImpl<SpirvInstruction *> *indices);
+                            llvm::SmallVectorImpl<SpirvInstruction *> *indices,
+                            bool *isMSOutAttribute = nullptr);
 
   /// Creates an access chain to index into the given SPIR-V evaluation result
   /// and returns the new SPIR-V evaluation result.
@@ -521,6 +540,12 @@ private:
   void processCallShader(const CallExpr *callExpr);
   void processTraceRay(const CallExpr *callExpr);
 
+  /// Process amplification shader intrinsics.
+  void processDispatchMesh(const CallExpr *callExpr);
+
+  /// Process mesh shader intrinsics.
+  void processMeshOutputCounts(const CallExpr *callExpr);
+
 private:
   /// Returns the <result-id> for constant value 0 of the given type.
   SpirvConstant *getValueZero(QualType type);
@@ -625,6 +650,12 @@ private:
   /// HLSL attributes of the entry point function.
   void processComputeShaderAttributes(const FunctionDecl *entryFunction);
 
+  /// \brief Adds necessary execution modes for the mesh/amplification shader
+  /// based on the HLSL attributes of the entry point function.
+  bool
+  processMeshOrAmplificationShaderAttributes(const FunctionDecl *decl,
+                                             uint32_t *outVerticesArraySize);
+
   /// \brief Emits a wrapper function for the entry function and returns true
   /// on success.
   ///
@@ -1105,7 +1136,7 @@ private:
   /// HitAttributeNV.
   llvm::SmallDenseMap<QualType,
                       std::pair<SpirvInstruction *, SpirvInstruction *>, 4>
-      payloadMap;
+      rayPayloadMap;
   llvm::SmallDenseMap<QualType, SpirvInstruction *, 4> hitAttributeMap;
   llvm::SmallDenseMap<QualType,
                       std::pair<SpirvInstruction *, SpirvInstruction *>, 4>

+ 32 - 0
tools/clang/lib/Sema/SemaExpr.cpp

@@ -608,6 +608,32 @@ static void DiagnoseDirectIsaAccess(Sema &S, const ObjCIvarRefExpr *OIRE,
     }
 }
 
+static bool IsExprAccessingMeshOutArray(Expr* BaseExpr) {
+  switch (BaseExpr->getStmtClass()) {
+  case Stmt::ArraySubscriptExprClass: {
+    ArraySubscriptExpr* ase = cast<ArraySubscriptExpr>(BaseExpr);
+    return IsExprAccessingMeshOutArray(ase->getBase());
+  }
+  case Stmt::ImplicitCastExprClass: {
+    ImplicitCastExpr* ice = cast<ImplicitCastExpr>(BaseExpr);
+    return IsExprAccessingMeshOutArray(ice->getSubExpr());
+  }
+  case Stmt::DeclRefExprClass: {
+    DeclRefExpr* dre = cast<DeclRefExpr>(BaseExpr);
+    ValueDecl* vd = dre->getDecl();
+    if (vd->getAttr<HLSLOutAttr>() &&
+        (vd->getAttr<HLSLIndicesAttr>() ||
+         vd->getAttr<HLSLVerticesAttr>() ||
+         vd->getAttr<HLSLPrimitivesAttr>())) {
+      return true;
+    }
+    return false;
+  }
+  default:
+    return false;
+  }
+}
+
 ExprResult Sema::DefaultLvalueConversion(Expr *E) {
   // Handle any placeholder expressions which made it here.
   if (E->getType()->isPlaceholderType()) {
@@ -666,6 +692,12 @@ ExprResult Sema::DefaultLvalueConversion(Expr *E) {
             dyn_cast<ObjCIvarRefExpr>(E->IgnoreParenCasts()))
     DiagnoseDirectIsaAccess(*this, OIRE, SourceLocation(), /* Expr*/nullptr);
 
+  // check the access to mesh shader output arrays
+  if (isa<ArraySubscriptExpr>(E) && IsExprAccessingMeshOutArray(E)) {
+    Diag(E->getExprLoc(), diag::err_hlsl_load_from_mesh_out_arrays);
+    return ExprError();
+  }
+
   // C++ [conv.lval]p1:
   //   [...] If T is a non-class type, the type of the prvalue is the
   //   cv-unqualified version of T. Otherwise, the type of the

+ 3 - 1
tools/clang/lib/Sema/SemaExprCXX.cpp

@@ -3143,7 +3143,9 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Lvalue_To_Rvalue: {
     assert(From->getObjectKind() != OK_ObjCProperty);
     ExprResult FromRes = DefaultLvalueConversion(From);
-    assert(!FromRes.isInvalid() && "Can't perform deduced conversion?!");
+    if (FromRes.isInvalid()) {
+      return ExprError();
+    }
     From = FromRes.get();
     FromType = From->getType();
     break;

+ 141 - 8
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -7323,6 +7323,29 @@ VectorMemberAccessError TryParseVectorMemberAccess(_In_z_ const char* memberText
   return VectorMemberAccessError_None;
 }
 
+static bool IsExprAccessingOutIndicesArray(Expr* BaseExpr) {
+  switch(BaseExpr->getStmtClass()) {
+  case Stmt::ArraySubscriptExprClass: {
+    ArraySubscriptExpr* ase = cast<ArraySubscriptExpr>(BaseExpr);
+    return IsExprAccessingOutIndicesArray(ase->getBase());
+  }
+  case Stmt::ImplicitCastExprClass: {
+    ImplicitCastExpr* ice = cast<ImplicitCastExpr>(BaseExpr);
+    return IsExprAccessingOutIndicesArray(ice->getSubExpr());
+  }
+  case Stmt::DeclRefExprClass: {
+    DeclRefExpr* dre = cast<DeclRefExpr>(BaseExpr);
+    ValueDecl* vd = dre->getDecl();
+    if (vd->getAttr<HLSLIndicesAttr>() && vd->getAttr<HLSLOutAttr>()) {
+      return true;
+    }
+    return false;
+  }
+  default:
+    return false;
+  }
+}
+
 bool HLSLExternalSource::LookupVectorMemberExprForHLSL(
     Expr& BaseExpr,
     DeclarationName MemberName,
@@ -7396,6 +7419,14 @@ bool HLSLExternalSource::LookupVectorMemberExprForHLSL(
 
   DXASSERT(positions.IsValid, "otherwise an error should have been returned");
 
+  // Disallow component access for out indices for DXIL path. We still allow
+  // this in SPIR-V path.
+  if (!getSema()->getLangOpts().SPIRV &&
+      IsExprAccessingOutIndicesArray(&BaseExpr) && positions.Count < colCount) {
+    m_sema->Diag(MemberLoc, diag::err_hlsl_out_indices_array_incorrect_access);
+    return false;
+  }
+
   // Consume elements
   QualType resultType;
   if (positions.Count == 1)
@@ -9713,8 +9744,9 @@ void hlsl::DiagnoseTranslationUnit(clang::Sema *self) {
     if (shaderModel->IsGS()) {
       // Validate that GS has the maxvertexcount attribute
       if (!pEntryPointDecl->hasAttr<HLSLMaxVertexCountAttr>()) {
-        self->Diag(pEntryPointDecl->getLocation(),
-                   diag::err_hlsl_missing_maxvertexcount_attr);
+        self->Diag(pEntryPointDecl->getLocation(), diag::err_hlsl_missing_attr)
+            << "GS"
+            << "maxvertexcount";
         return;
       }
     } else if (shaderModel->IsHS()) {
@@ -9731,8 +9763,32 @@ void hlsl::DiagnoseTranslationUnit(clang::Sema *self) {
         }
         pPatchFnDecl = NL.Found;
       } else {
-        self->Diag(pEntryPointDecl->getLocation(),
-                   diag::err_hlsl_missing_patchconstantfunc_attr);
+        self->Diag(pEntryPointDecl->getLocation(), diag::err_hlsl_missing_attr)
+            << "HS"
+            << "patchconstantfunc";
+        return;
+      }
+    } else if (shaderModel->IsMS()) {
+      // Validate that MS has the numthreads attribute
+      if (!pEntryPointDecl->hasAttr<HLSLNumThreadsAttr>()) {
+        self->Diag(pEntryPointDecl->getLocation(), diag::err_hlsl_missing_attr)
+            << "MS"
+            << "numthreads";
+        return;
+      }
+      // Validate that MS has the outputtopology attribute
+      if (!pEntryPointDecl->hasAttr<HLSLOutputTopologyAttr>()) {
+        self->Diag(pEntryPointDecl->getLocation(), diag::err_hlsl_missing_attr)
+            << "MS"
+            << "outputtopology";
+        return;
+      }
+    } else if (shaderModel->IsAS()) {
+      // Validate that AS has the numthreads attribute
+      if (!pEntryPointDecl->hasAttr<HLSLNumThreadsAttr>()) {
+        self->Diag(pEntryPointDecl->getLocation(), diag::err_hlsl_missing_attr)
+            << "AS"
+            << "numthreads";
         return;
       }
     }
@@ -10892,6 +10948,22 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
     declAttr = ::new (S.Context) HLSLGloballyCoherentAttr(
         A.getRange(), S.Context, A.getAttributeSpellingListIndex());
     break;
+  case AttributeList::AT_HLSLIndices:
+    declAttr = ::new (S.Context) HLSLIndicesAttr(
+        A.getRange(), S.Context, A.getAttributeSpellingListIndex());
+    break;
+  case AttributeList::AT_HLSLVertices:
+    declAttr = ::new (S.Context) HLSLVerticesAttr(
+        A.getRange(), S.Context, A.getAttributeSpellingListIndex());
+    break;
+  case AttributeList::AT_HLSLPrimitives:
+    declAttr = ::new (S.Context) HLSLPrimitivesAttr(
+        A.getRange(), S.Context, A.getAttributeSpellingListIndex());
+    break;
+  case AttributeList::AT_HLSLPayload:
+    declAttr = ::new (S.Context) HLSLPayloadAttr(
+        A.getRange(), S.Context, A.getAttributeSpellingListIndex());
+    break;
 
   default:
     Handled = false;
@@ -10975,8 +11047,10 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
   case AttributeList::AT_HLSLShader:
     declAttr = ::new (S.Context) HLSLShaderAttr(
         A.getRange(), S.Context,
-        ValidateAttributeStringArg(S, A,
-                                   "compute,vertex,pixel,hull,domain,geometry,raygeneration,intersection,anyhit,closesthit,miss,callable"),
+        ValidateAttributeStringArg(
+            S, A,
+            "compute,vertex,pixel,hull,domain,geometry,raygeneration,"
+            "intersection,anyhit,closesthit,miss,callable,mesh,amplification"),
         A.getAttributeSpellingListIndex());
     break;
   case AttributeList::AT_HLSLMaxVertexCount:
@@ -11017,7 +11091,7 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
   {
   case AttributeList::AT_VKBuiltIn:
     declAttr = ::new (S.Context) VKBuiltInAttr(A.getRange(), S.Context,
-      ValidateAttributeStringArg(S, A, "PointSize,HelperInvocation,BaseVertex,BaseInstance,DrawIndex,DeviceIndex"),
+      ValidateAttributeStringArg(S, A, "PointSize,HelperInvocation,BaseVertex,BaseInstance,DrawIndex,DeviceIndex,ViewportMaskNV"),
       A.getAttributeSpellingListIndex());
     break;
   case AttributeList::AT_VKLocation:
@@ -11549,7 +11623,8 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC, Expr *BitWidth,
     *pCentroid = nullptr,
     *pCenter = nullptr,
     *pAnyLinear = nullptr,                   // first linear attribute found
-    *pTopology = nullptr;
+    *pTopology = nullptr,
+    *pMeshModifier = nullptr;
   bool usageIn = false;
   bool usageOut = false;
 
@@ -11733,6 +11808,29 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC, Expr *BitWidth,
       }
       break;
 
+    case AttributeList::AT_HLSLIndices:
+    case AttributeList::AT_HLSLVertices:
+    case AttributeList::AT_HLSLPrimitives:
+    case AttributeList::AT_HLSLPayload:
+      if (!(isParameter)) {
+        Diag(pAttr->getLoc(), diag::err_hlsl_varmodifierna)
+          << pAttr->getName() << declarationType << pAttr->getRange();
+        result = false;
+      }
+      if (pMeshModifier) {
+        if (pMeshModifier->getKind() == pAttr->getKind()) {
+          Diag(pAttr->getLoc(), diag::warn_hlsl_duplicate_specifier)
+            << pAttr->getName() << pAttr->getRange();
+        } else {
+          Diag(pAttr->getLoc(), diag::err_hlsl_varmodifiersna)
+            << pAttr->getName() << pMeshModifier->getName()
+            << declarationType << pAttr->getRange();
+          result = false;
+        }
+      }
+      pMeshModifier = pAttr;
+      break;
+
     default:
       break;
     }
@@ -11784,6 +11882,21 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC, Expr *BitWidth,
       result = false;
     }
   }
+  if (pMeshModifier) {
+    if (pMeshModifier->getKind() == AttributeList::Kind::AT_HLSLPayload) {
+      if (!usageIn) {
+        Diag(D.getLocStart(), diag::err_hlsl_missing_in_attr)
+            << pMeshModifier->getName();
+        result = false;
+      }
+    } else {
+      if (!usageOut) {
+        Diag(D.getLocStart(), diag::err_hlsl_missing_out_attr)
+            << pMeshModifier->getName();
+        result = false;
+      }
+    }
+  }
 
   // Validate that stream-ouput objects are marked as inout
   if (isParameter && !(usageIn && usageOut) &&
@@ -12313,6 +12426,22 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, con
     Out << "globallycoherent ";
     break;
 
+  case clang::attr::HLSLIndices:
+    Out << "indices ";
+    break;
+
+  case clang::attr::HLSLVertices:
+    Out << "vertices ";
+    break;
+
+  case clang::attr::HLSLPrimitives:
+    Out << "primitives ";
+    break;
+
+  case clang::attr::HLSLPayload:
+    Out << "payload ";
+    break;
+
   default:
     A->printPretty(Out, Policy);
     break;
@@ -12365,6 +12494,10 @@ bool hlsl::IsHLSLAttr(clang::attr::Kind AttrKind) {
   case clang::attr::HLSLTriangle:
   case clang::attr::HLSLTriangleAdj:
   case clang::attr::HLSLGloballyCoherent:
+  case clang::attr::HLSLIndices:
+  case clang::attr::HLSLVertices:
+  case clang::attr::HLSLPrimitives:
+  case clang::attr::HLSLPayload:
   case clang::attr::NoInline:
   case clang::attr::HLSLExport:
   case clang::attr::VKBinding:

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 158 - 142
tools/clang/lib/Sema/gen_intrin_main_tables_15.h


+ 1 - 1
tools/clang/test/CodeGenHLSL/batch/expressions/intrinsics/misc/abs1.hlsl

@@ -2,7 +2,7 @@
 
 // CHECK: main
 // After lowering, these would turn into multiple abs calls rather than a 4 x float
-// CHECK: call <4 x float> @"dx.hl.op..<4 x float> (i32, <4 x float>)"(i32 91,
+// CHECK: call <4 x float> @"dx.hl.op..<4 x float> (i32, <4 x float>)"(i32 93,
 
 float4 main(float4 a : A) : SV_TARGET {
   return abs(a*a.yxxx);

+ 22 - 0
tools/clang/test/CodeGenHLSL/mesh-val/amplification.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: dx.op.dispatchMesh.struct.Payload
+
+#define NUM_THREADS 32
+
+struct Payload {
+    float2 dummy;
+    float4 pos;
+    float color[2];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main()
+{
+    Payload pld;
+    pld.dummy = float2(1.0,2.0);
+    pld.pos = float4(3.0,4.0,5.0,6.0);
+    pld.color[0] = 7.0;
+    pld.color[1] = 8.0;
+    DispatchMesh(NUM_THREADS, 1, 1, pld);
+}

+ 22 - 0
tools/clang/test/CodeGenHLSL/mesh-val/asOversizePayload.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: payload size is greater than 16384
+
+#define NUM_THREADS 32
+
+struct Payload {
+    float2 dummy;
+    float4 pos[1024];
+    float color[2];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main()
+{
+    Payload pld;
+    pld.dummy = float2(1.0,2.0);
+    pld.pos[0] = float4(3.0,4.0,5.0,6.0);
+    pld.color[0] = 7.0;
+    pld.color[1] = 8.0;
+    DispatchMesh(NUM_THREADS, 1, 1, pld);
+}

+ 78 - 0
tools/clang/test/CodeGenHLSL/mesh-val/mesh.hlsl

@@ -0,0 +1,78 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: dx.op.getMeshPayload.struct.MeshPayload
+// CHECK: dx.op.setMeshOutputCounts(i32 168, i32 32, i32 16)
+// CHECK: dx.op.emitIndices
+// CHECK: dx.op.storeVertexOutput
+// CHECK: dx.op.storePrimitiveOutput
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    float alnorm : ALNORM;
+    float ormaln : ORMALN;
+    int layer[6] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    float alnorm;
+    float ormaln;
+    int layer[6];
+};
+
+groupshared float gsMem[MAX_PRIM];
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = gsMem[tig / 3 + 1];
+      op.alnorm = mpl.alnorm;
+      op.ormaln = mpl.ormaln;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      op.layer[4] = mpl.layer[4];
+      op.layer[5] = mpl.layer[5];
+      gsMem[tig / 3] = op.normal;
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 21 - 0
tools/clang/test/CodeGenHLSL/mesh-val/missingDispatchMesh.hlsl

@@ -0,0 +1,21 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: DispatchMesh must be called exactly once in an Amplification shader.
+
+#define NUM_THREADS 32
+
+struct Payload {
+    float2 dummy;
+    float4 pos;
+    float color[2];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main()
+{
+    Payload pld;
+    pld.dummy = float2(1.0,2.0);
+    pld.pos = float4(3.0,4.0,5.0,6.0);
+    pld.color[0] = 7.0;
+    pld.color[1] = 8.0;
+}

+ 62 - 0
tools/clang/test/CodeGenHLSL/mesh-val/missingSetMeshOutputCounts.hlsl

@@ -0,0 +1,62 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: Missing SetMeshOutputCounts call.
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh-val/msOversizePayload.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: payload size is greater than 16384
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor[4096];
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor[0];
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 23 - 0
tools/clang/test/CodeGenHLSL/mesh-val/multipleDispatchMesh.hlsl

@@ -0,0 +1,23 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: DispatchMesh must be called exactly once in an Amplification shader.
+
+#define NUM_THREADS 32
+
+struct Payload {
+    float2 dummy;
+    float4 pos;
+    float color[2];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main()
+{
+    Payload pld;
+    pld.dummy = float2(1.0,2.0);
+    pld.pos = float4(3.0,4.0,5.0,6.0);
+    pld.color[0] = 7.0;
+    pld.color[1] = 8.0;
+    DispatchMesh(NUM_THREADS, 1, 1, pld);
+    DispatchMesh(NUM_THREADS, 1, 1, pld);
+}

+ 64 - 0
tools/clang/test/CodeGenHLSL/mesh-val/multipleSetMeshOutputCounts.hlsl

@@ -0,0 +1,64 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: SetMeshOUtputCounts cannot be called multiple times.
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 24 - 0
tools/clang/test/CodeGenHLSL/mesh-val/nonDominatingDispatchMesh.hlsl

@@ -0,0 +1,24 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: Non-Dominating DispatchMesh call.
+
+#define NUM_THREADS 32
+
+struct Payload {
+    float2 dummy;
+    float4 pos;
+    float color[2];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main(in uint tid : SV_DispatchThreadID)
+{
+    Payload pld;
+    pld.dummy = float2(1.0,2.0);
+    pld.pos = float4(3.0,4.0,5.0,6.0);
+    pld.color[0] = 7.0;
+    pld.color[1] = 8.0;
+    if (tid % 2) {
+      DispatchMesh(NUM_THREADS, 1, 1, pld);
+    }
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh-val/nonDominatingSetMeshOutputCounts.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: Non-Dominating SetMeshOutputCounts call.
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    MeshPerVertex ov;
+    if (vid % 2) {
+        SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 74 - 0
tools/clang/test/CodeGenHLSL/mesh-val/oversizeSM.hlsl

@@ -0,0 +1,74 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: Total Thread Group Shared Memory storage is 28676, exceeded 28672
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    float alnorm : ALNORM;
+    float ormaln : ORMALN;
+    int layer[6] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    float alnorm;
+    float ormaln;
+    int layer[6];
+};
+
+groupshared float gsMem[1024 * 7 + 1];
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = gsMem[tig / 3 + 1];
+      op.alnorm = mpl.alnorm;
+      op.ormaln = mpl.ormaln;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      op.layer[4] = mpl.layer[4];
+      op.layer[5] = mpl.layer[5];
+      gsMem[tig / 3] = op.normal;
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 22 - 0
tools/clang/test/CodeGenHLSL/mesh/amplification.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: dx.op.dispatchMesh.struct.Payload
+
+#define NUM_THREADS 32
+
+struct Payload {
+    float2 dummy;
+    float4 pos;
+    float color[2];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main()
+{
+    Payload pld;
+    pld.dummy = float2(1.0,2.0);
+    pld.pos = float4(3.0,4.0,5.0,6.0);
+    pld.color[0] = 7.0;
+    pld.color[1] = 8.0;
+    DispatchMesh(NUM_THREADS, 1, 1, pld);
+}

+ 64 - 0
tools/clang/test/CodeGenHLSL/mesh/illegalOutIndicesAssignment.hlsl

@@ -0,0 +1,64 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: a vector in out indices array must be accessed as a whole
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      primIndices[tig / 3].x = 0;
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 78 - 0
tools/clang/test/CodeGenHLSL/mesh/mesh.hlsl

@@ -0,0 +1,78 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: dx.op.getMeshPayload.struct.MeshPayload
+// CHECK: dx.op.setMeshOutputCounts(i32 168, i32 32, i32 16)
+// CHECK: dx.op.emitIndices
+// CHECK: dx.op.storeVertexOutput
+// CHECK: dx.op.storePrimitiveOutput
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    float alnorm : ALNORM;
+    float ormaln : ORMALN;
+    int layer[6] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    float alnorm;
+    float ormaln;
+    int layer[6];
+};
+
+groupshared float gsMem[MAX_PRIM];
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = gsMem[tig / 3 + 1];
+      op.alnorm = mpl.alnorm;
+      op.ormaln = mpl.ormaln;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      op.layer[4] = mpl.layer[4];
+      op.layer[5] = mpl.layer[5];
+      gsMem[tig / 3] = op.normal;
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 64 - 0
tools/clang/test/CodeGenHLSL/mesh/multipleInPayload.hlsl

@@ -0,0 +1,64 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: multiple in payload parameters not allowed
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in payload MeshPayload mpl2,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl2.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl2.layer[2];
+      op.layer[3] = mpl2.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 65 - 0
tools/clang/test/CodeGenHLSL/mesh/multipleOutIndices.hlsl

@@ -0,0 +1,65 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: multiple out indices parameters not allowed
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out indices uint3 primIndices2[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      primIndices2[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 65 - 0
tools/clang/test/CodeGenHLSL/mesh/multipleOutPrimitives.hlsl

@@ -0,0 +1,65 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: multiple out primitives parameters not allowed
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            out primitives MeshPerPrimitive prims2[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+      prims2[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 65 - 0
tools/clang/test/CodeGenHLSL/mesh/multipleOutVertices.hlsl

@@ -0,0 +1,65 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: multiple out vertices parameters not allowed
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out vertices MeshPerVertex verts2[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+    verts2[tig] = ov;
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh/notArrayOutIndices.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: indices output is not an constant-length array
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint primIndices,
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices = tig;
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh/notArrayOutPrimitives.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: primitives output is not an constant-length array
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims,
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims = op;
+    }
+    verts[tig] = ov;
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh/notArrayOutVertices.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: vertices output is not an constant-length array
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts,
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts = ov;
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh/notUint2OutIndicesForLines.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: the element of out_indices array in a mesh shader whose output topology is line must be uint2
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("line")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh/notUint3OutIndicesForTriangles.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: the element of out_indices array in a mesh shader whose output topology is triangle must be uint3
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint2 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint2(tig, tig + 1);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh/notUintOutIndices.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: the element of out_indices array must be uint2 for line output or uint3 for triangle output
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices float3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = float3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh/notVectorOutIndices.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: the element of out_indices array must be uint2 for line output or uint3 for triangle output
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = tig;
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 64 - 0
tools/clang/test/CodeGenHLSL/mesh/readFromOutIndices.hlsl

@@ -0,0 +1,64 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: output arrays of a mesh shader can not be read from
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      primIndices[tig / 3 + 1] = primIndices[tig / 3];
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh/tooManyOutIndices.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: max primitive count should not exceed 256
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[257],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh/tooManyOutPrimitives.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: max primitive count should not exceed 256
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[257],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh/tooManyOutVertices.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: error: max vertex count should not exceed 256
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[257],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 62 - 0
tools/clang/test/CodeGenSPIRV/meshshading.nv.amplification.hlsl

@@ -0,0 +1,62 @@
+// Run: %dxc -T as_6_5 -E main
+// CHECK:  OpCapability MeshShadingNV
+// CHECK:  OpExtension "SPV_NV_mesh_shader"
+// CHECK:  OpEntryPoint TaskNV %main "main" [[drawid:%\d+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex [[taskcount:%\d+]] %out_var_dummy %out_var_pos
+// CHECK:  OpExecutionMode %main LocalSize 128 1 1
+
+// CHECK:  OpDecorate [[drawid]] BuiltIn DrawIndex
+// CHECK:  OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
+// CHECK:  OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId
+// CHECK:  OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
+// CHECK:  OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex
+// CHECK:  OpDecorate [[taskcount]] BuiltIn TaskCountNV
+
+struct MeshPayload {
+// CHECK:  OpDecorate %out_var_dummy PerTaskNV
+// CHECK:  OpDecorate %out_var_dummy Offset 0
+// CHECK:  OpDecorate %out_var_pos PerTaskNV
+// CHECK:  OpDecorate %out_var_pos Offset 48
+    float dummy[10];
+    float4 pos;
+};
+
+// CHECK:  %pld = OpVariable %_ptr_Workgroup_MeshPayload Workgroup
+groupshared MeshPayload pld;
+
+// CHECK:  %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
+// CHECK:  %gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input
+// CHECK:  [[taskcount]] = OpVariable %_ptr_Output_uint Output
+// CHECK:  %out_var_dummy = OpVariable %_ptr_Output__arr_float_uint_10 Output
+// CHECK:  %out_var_pos = OpVariable %_ptr_Output_v4float Output
+
+#define NUM_THREADS 128
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main(
+// CHECK:  %param_var_drawId = OpVariable %_ptr_Function_int Function
+// CHECK:  %param_var_gtid = OpVariable %_ptr_Function_v3uint Function
+// CHECK:  %param_var_gid = OpVariable %_ptr_Function_v2uint Function
+// CHECK:  %param_var_tid = OpVariable %_ptr_Function_uint Function
+// CHECK:  %param_var_tig = OpVariable %_ptr_Function_uint Function
+        [[vk::builtin("DrawIndex")]] in int drawId : DRAW,  // -> BuiltIn DrawIndex
+        in uint3 gtid : SV_GroupThreadID,
+        in uint2 gid : SV_GroupID,
+        in uint tid : SV_DispatchThreadID,
+        in uint tig : SV_GroupIndex)
+{
+
+// CHECK:  [[a:%\d+]] = OpAccessChain %_ptr_Workgroup_v4float %pld %int_1
+// CHECK:  OpStore [[a]] {{%\d+}}
+    pld.pos = float4(3.0,4.0,5.0,6.0);
+
+// CHECK:  OpControlBarrier %uint_2 %uint_2 %uint_264
+// CHECK:  [[c:%\d+]] = OpIMul %uint %uint_1 %uint_1
+// CHECK:  [[d:%\d+]] = OpIMul %uint %uint_128 [[c]]
+// CHECK:  OpStore [[taskcount]] [[d]]
+// CHECK:  [[e:%\d+]] = OpLoad %MeshPayload %pld
+// CHECK:  [[f:%\d+]] = OpCompositeExtract %_arr_float_uint_10 [[e]] 0
+// CHECK:  OpStore %out_var_dummy [[f]]
+// CHECK:  [[g:%\d+]] = OpCompositeExtract %v4float [[e]] 1
+// CHECK:  OpStore %out_var_pos [[g]]
+    DispatchMesh(NUM_THREADS, 1, 1, pld);
+}

+ 14 - 0
tools/clang/test/CodeGenSPIRV/meshshading.nv.error1.amplification.hlsl

@@ -0,0 +1,14 @@
+// Run: %dxc -T as_6_5 -E main
+
+// CHECK:  11:6: error: AS entry point must have the numthreads attribute
+
+struct MeshPayload {
+    float4 pos;
+};
+
+groupshared MeshPayload pld;
+
+void main(
+        in uint tig : SV_GroupIndex)
+{
+}

+ 19 - 0
tools/clang/test/CodeGenSPIRV/meshshading.nv.error1.mesh.hlsl

@@ -0,0 +1,19 @@
+// Run: %dxc -T ms_6_5 -E main
+
+// CHECK: 14:6: error: MS entry point must have the outputtopology attribute
+
+struct MeshPerVertex {
+    float4 position : SV_Position;
+};
+
+#define MAX_VERT 64
+#define MAX_PRIM 81
+#define NUM_THREADS 128
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main(
+        out vertices MeshPerVertex verts[MAX_VERT],
+        out indices uint3 primitiveInd[MAX_PRIM],
+        in uint tig : SV_GroupIndex)
+{
+}

Kaikkia tiedostoja ei voida näyttää, sillä liian monta tiedostoa muuttui tässä diffissä