
[SPIR-V] Add support for the extension VK_EXT_mesh_shader (#4725)

* Support VK_EXT_mesh_shader

* Fix errors when compiling with SPV_NV_mesh_shader

* Minor tweak onto amplification shaders

* Fix pre-checkin failure

* Add amplification and mesh tests for EXT_mesh_shader

* Add some comments

* Fix primitive indices

Co-authored-by: Tianyuan <[email protected]>
TY-AMD, 2 years ago
parent
commit
97bafd2a60

+ 59 - 23
docs/SPIR-V.rst

@@ -294,6 +294,7 @@ Supported extensions
 * SPV_KHR_shader_draw_parameters
 * SPV_EXT_descriptor_indexing
 * SPV_EXT_fragment_fully_covered
+* SPV_EXT_mesh_shader
 * SPV_EXT_shader_stencil_support
 * SPV_AMD_shader_early_and_late_fragment_tests
 * SPV_AMD_shader_explicit_vertex_parameter
@@ -1522,13 +1523,15 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 |                           +-------------+----------------------------------------+-----------------------+-----------------------------+
 |                           | DsIn        | ``PrimitiveId``                        | N/A                   | ``Tessellation``            |
 |                           +-------------+----------------------------------------+-----------------------+-----------------------------+
-| SV_PrimitiveID            | GSIn        | ``PrimitiveId``                        | N/A                   | ``Geometry``                |
-|                           +-------------+----------------------------------------+-----------------------+-----------------------------+
+|                           | GSIn        | ``PrimitiveId``                        | N/A                   | ``Geometry``                |
+| SV_PrimitiveID            +-------------+----------------------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``PrimitiveId``                        | N/A                   | ``Geometry``                |
 |                           +-------------+----------------------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``PrimitiveId``                        | N/A                   | ``Geometry``                |
 |                           +-------------+----------------------------------------+-----------------------+-----------------------------+
-|                           | MSOut       | ``PrimitiveId``                        | N/A                   | ``MeshShadingNV``           |
+|                           |             |                                        |                       | ``MeshShadingNV``           |
+|                           | MSOut       | ``PrimitiveId``                        | N/A                   |                             |
+|                           |             |                                        |                       | ``MeshShadingEXT``          |
 +---------------------------+-------------+----------------------------------------+-----------------------+-----------------------------+
 |                           | PCOut       | ``TessLevelOuter``                     | N/A                   | ``Tessellation``            |
 | SV_TessFactor             +-------------+----------------------------------------+-----------------------+-----------------------------+
@@ -1546,15 +1549,19 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 +---------------------------+-------------+----------------------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``Layer``                              | N/A                   | ``Geometry``                |
 |                           +-------------+----------------------------------------+-----------------------+-----------------------------+
-| SV_RenderTargetArrayIndex | PSIn        | ``Layer``                              | N/A                   | ``Geometry``                |
-|                           +-------------+----------------------------------------+-----------------------+-----------------------------+
-|                           | MSOut       | ``Layer``                              | N/A                   | ``MeshShadingNV``           |
+|                           | PSIn        | ``Layer``                              | N/A                   | ``Geometry``                |
+| SV_RenderTargetArrayIndex +-------------+----------------------------------------+-----------------------+-----------------------------+
+|                           |             |                                        |                       | ``MeshShadingNV``           |
+|                           | MSOut       | ``Layer``                              | N/A                   |                             |
+|                           |             |                                        |                       | ``MeshShadingEXT``          |
 +---------------------------+-------------+----------------------------------------+-----------------------+-----------------------------+
 |                           | GSOut       | ``ViewportIndex``                      | N/A                   | ``MultiViewport``           |
 |                           +-------------+----------------------------------------+-----------------------+-----------------------------+
-| SV_ViewportArrayIndex     | PSIn        | ``ViewportIndex``                      | N/A                   | ``MultiViewport``           |
-|                           +-------------+----------------------------------------+-----------------------+-----------------------------+
-|                           | MSOut       | ``ViewportIndex``                      | N/A                   | ``MeshShadingNV``           |
+|                           | PSIn        | ``ViewportIndex``                      | N/A                   | ``MultiViewport``           |
+| SV_ViewportArrayIndex     +-------------+----------------------------------------+-----------------------+-----------------------------+
+|                           |             |                                        |                       | ``MeshShadingNV``           |
+|                           | MSOut       | ``ViewportIndex``                      | N/A                   |                             |
+|                           |             |                                        |                       | ``MeshShadingEXT``          |
 +---------------------------+-------------+----------------------------------------+-----------------------+-----------------------------+
 |                           | PSIn        | ``SampleMask``                         | N/A                   | ``Shader``                  |
 | SV_Coverage               +-------------+----------------------------------------+-----------------------+-----------------------------+
@@ -1582,6 +1589,9 @@ some system-value (SV) semantic strings will be translated into SPIR-V
 |                           +-------------+----------------------------------------+-----------------------+-----------------------------+
 |                           | MSOut       | ``PrimitiveShadingRateKHR``            | N/A                   | ``FragmentShadingRate``     |
 +---------------------------+-------------+----------------------------------------+-----------------------+-----------------------------+
+| SV_CullPrimitive          | MSOut       | ``CullPrimitiveEXT``                   | N/A                   | ``MeshShadingEXT``          |
++---------------------------+-------------+----------------------------------------+-----------------------+-----------------------------+
+
 
 For entities (function parameters, function return values, struct fields) with
 the above SV semantic strings attached, SPIR-V variables of the
@@ -3409,26 +3419,34 @@ shaders and are translated to SPIR-V execution modes according to the table belo
 
 .. table:: Mapping from HLSL attribute to SPIR-V execution mode
 
-+-------------------+--------------------+-------------------------+
-|  HLSL Attribute   |   Value            | SPIR-V Execution Mode   |
-+===================+====================+=========================+
-|``outputtopology`` | ``point``          | ``OutputPoints``        |
-|                   +--------------------+-------------------------+
-|``(Mesh shader)``  | ``line``           | ``OutputLinesNV``       |
-|                   +--------------------+-------------------------+
-|                   | ``triangle``       | ``OutputTrianglesNV``   |
-+-------------------+--------------------+-------------------------+
-| ``numthreads``    | ``X, Y, Z``        | ``LocalSize X, Y, Z``   |
-|                   |                    |                         |
-|                   | ``(X*Y*Z <= 128)`` |                         |
-+-------------------+--------------------+-------------------------+
++-----------------------+--------------------+-------------------------+
+|  HLSL Attribute       |   Value            | SPIR-V Execution Mode   |
++=======================+====================+=========================+
+|``outputtopology``     | ``point``          | ``OutputPoints``        |
+|                       +--------------------+-------------------------+
+| (SPV_NV_mesh_shader)  | ``line``           | ``OutputLinesNV``       |
+|                       |                    |                         |
+|                       +--------------------+-------------------------+
+|                       | ``triangle``       | ``OutputTrianglesNV``   |
++-----------------------+--------------------+-------------------------+
+|``outputtopology``     | ``point``          | ``OutputPoints``        |
+|                       +--------------------+-------------------------+
+| (SPV_EXT_mesh_shader) | ``line``           | ``OutputLinesEXT``      |
+|                       |                    |                         |
+|                       +--------------------+-------------------------+
+|                       | ``triangle``       | ``OutputTrianglesEXT``  |
++-----------------------+--------------------+-------------------------+
+| ``numthreads``        | ``X, Y, Z``        | ``LocalSize X, Y, Z``   |
+|                       |                    |                         |
+|                       | ``(X*Y*Z <= 128)`` |                         |
++-----------------------+--------------------+-------------------------+
 
 Intrinsics
 ~~~~~~~~~~
 The following HLSL intrinsics are used in Mesh or Amplification shaders
 and are translated to SPIR-V intrinsics according to the table below:
 
-.. table:: Mapping from HLSL intrinsics to SPIR-V intrinsics
+.. table:: Mapping from HLSL intrinsics to SPIR-V intrinsics for SPV_NV_mesh_shader
 
 +---------------------------+--------------------+-----------------------------------------+
 |  HLSL Intrinsic           |  Parameters        | SPIR-V Intrinsic                        |
@@ -3446,6 +3464,24 @@ and are translated to SPIR-V intrinsics according to the table below:
 |                           | ``MeshPayload``    |                                         |
 +---------------------------+--------------------+-----------------------------------------+
 
+.. table:: Mapping from HLSL intrinsics to SPIR-V intrinsics for SPV_EXT_mesh_shader
+
++---------------------------+--------------------+--------------------------------------------------------------+
+|  HLSL Intrinsic           |  Parameters        | SPIR-V Intrinsic                                             |
++===========================+====================+==============================================================+
+| ``SetMeshOutputCounts``   | ``numVertices``    | ``OpSetMeshOutputsEXT``                                      |
+|                           |                    |                                                              |
+| ``(Mesh shader)``         | ``numPrimitives``  |                                                              |
++---------------------------+--------------------+--------------------------------------------------------------+
+| ``DispatchMesh``          | ``ThreadX``        | ``OpEmitMeshTasksEXT ThreadX ThreadY ThreadZ MeshPayload``   |
+|                           |                    |                                                              |
+| ``(Amplification shader)``| ``ThreadY``        | ``TaskCountNV ThreadX*ThreadY*ThreadZ``                      |
+|                           |                    |                                                              |
+|                           | ``ThreadZ``        |                                                              |
+|                           |                    |                                                              |
+|                           | ``MeshPayload``    |                                                              |
++---------------------------+--------------------+--------------------------------------------------------------+
+
 | Note : For ``DispatchMesh`` intrinsic, we also emit ``MeshPayload`` as output block with ``PerTaskNV`` decoration
 
 Mesh Interface Variables

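The execution-mode table in the docs/SPIR-V.rst hunk above can be read as a small lookup keyed on the `outputtopology` value and the mesh extension in use. The sketch below only illustrates that reading; `Topology`, `executionModeFor`, and `isValidNumThreads` are hypothetical names and are not part of DXC.

```cpp
#include <cassert>
#include <string>

// Hypothetical enum for the three `outputtopology` values in the table.
enum class Topology { Point, Line, Triangle };

// Returns the SPIR-V execution-mode name listed in the table. Only the
// line and triangle modes carry an extension suffix.
std::string executionModeFor(Topology topology, bool useExtMeshShader) {
  const std::string suffix = useExtMeshShader ? "EXT" : "NV";
  switch (topology) {
  case Topology::Point:
    return "OutputPoints"; // same name under both extensions
  case Topology::Line:
    return "OutputLines" + suffix;
  case Topology::Triangle:
    return "OutputTriangles" + suffix;
  }
  assert(false && "unhandled topology");
  return "";
}

// The `numthreads` row requires X*Y*Z <= 128. The product is computed in
// 64 bits so that large inputs cannot wrap around.
bool isValidNumThreads(unsigned x, unsigned y, unsigned z) {
  return static_cast<unsigned long long>(x) * y * z <= 128ULL;
}
```

For instance, a `triangle` topology yields `OutputTrianglesNV` or `OutputTrianglesEXT` depending on the flag, while `point` yields `OutputPoints` in both cases.
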
+ 1 - 0
tools/clang/include/clang/SPIRV/FeatureManager.h

@@ -43,6 +43,7 @@ enum class Extension {
   EXT_descriptor_indexing,
   EXT_fragment_fully_covered,
   EXT_fragment_invocation_density,
+  EXT_mesh_shader,
   EXT_shader_stencil_export,
   EXT_shader_viewport_index_layer,
   AMD_gpu_shader_half_float,

+ 14 - 0
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -459,6 +459,20 @@ public:
   /// \brief Creates an OpEndPrimitive instruction.
   void createEndPrimitive(SourceLocation, SourceRange range = {});
 
+  /// \brief Creates an OpEmitMeshTasksEXT instruction.
+  void createEmitMeshTasksEXT(SpirvInstruction* xDim,
+                              SpirvInstruction* yDim,
+                              SpirvInstruction* zDim,
+                              SourceLocation loc,
+                              SpirvInstruction *payload = nullptr,
+                              SourceRange range = {});
+
+  /// \brief Creates an OpSetMeshOutputsEXT instruction.
+  void createSetMeshOutputsEXT(SpirvInstruction* vertCount,
+                               SpirvInstruction* primCount,
+                               SourceLocation loc,
+                               SourceRange range = {});
+
   /// \brief Creates an OpArrayLength instruction.
   SpirvArrayLength *createArrayLength(QualType resultType, SourceLocation loc,
                                       SpirvInstruction *structure,

+ 59 - 1
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -84,6 +84,7 @@ public:
     IK_Switch,              // OpSwitch
     IK_Unreachable,         // OpUnreachable
     IK_RayTracingTerminate, // OpIgnoreIntersectionKHR/OpTerminateRayKHR
+    IK_EmitMeshTasksEXT,    // OpEmitMeshTasksEXT
 
     // Normal instruction kinds
     // In alphabetical order
@@ -107,6 +108,8 @@ public:
     IK_EndPrimitive, // OpEndPrimitive
     IK_EmitVertex,   // OpEmitVertex
 
+    IK_SetMeshOutputsEXT,       // OpSetMeshOutputsEXT
+
     // The following section is for group non-uniform instructions.
     // Used by LLVM-style RTTI; order matters.
     IK_GroupNonUniformBinaryOp, // Group non-uniform binary operations
@@ -664,7 +667,7 @@ public:
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() >= IK_Branch &&
-           inst->getKind() <= IK_RayTracingTerminate;
+           inst->getKind() <= IK_EmitMeshTasksEXT;
   }
 
 protected:
@@ -2153,6 +2156,61 @@ private:
   SpirvExtInstImport *instructionSet;
 };
 
+/// \brief OpEmitMeshTasksEXT instruction.
+class SpirvEmitMeshTasksEXT : public SpirvInstruction {
+public:
+  SpirvEmitMeshTasksEXT(SpirvInstruction* xDim,
+                        SpirvInstruction* yDim,
+                        SpirvInstruction* zDim,
+                        SpirvInstruction* payload,
+                        SourceLocation loc, SourceRange range = {});
+
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvEmitMeshTasksEXT)
+
+  // For LLVM-style RTTI
+  static bool classof(const SpirvInstruction *inst) {
+    return inst->getKind() == IK_EmitMeshTasksEXT;
+  }
+
+  bool invokeVisitor(Visitor *v) override;
+
+  SpirvInstruction *getXDimension() const { return xDim; }
+  SpirvInstruction *getYDimension() const { return yDim; }
+  SpirvInstruction *getZDimension() const { return zDim; }
+  SpirvInstruction *getPayload() const { return payload; }
+
+private:
+  SpirvInstruction *xDim;
+  SpirvInstruction *yDim;
+  SpirvInstruction *zDim;
+  SpirvInstruction *payload;
+};
+
+/// \brief OpSetMeshOutputsEXT instruction.
+class SpirvSetMeshOutputsEXT : public SpirvInstruction {
+public:
+  SpirvSetMeshOutputsEXT(SpirvInstruction* vertCount,
+                         SpirvInstruction* primCount, 
+                         SourceLocation loc,
+                         SourceRange range = {});
+
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvSetMeshOutputsEXT)
+
+  // For LLVM-style RTTI
+  static bool classof(const SpirvInstruction *inst) {
+    return inst->getKind() == IK_SetMeshOutputsEXT;
+  }
+
+  bool invokeVisitor(Visitor *v) override;
+
+  SpirvInstruction *getVertexCount() const { return vertCount; }
+  SpirvInstruction *getPrimitiveCount() const { return primCount; }
+
+private:
+  SpirvInstruction *vertCount;
+  SpirvInstruction *primCount;
+};
+
 class SpirvDebugInfoNone : public SpirvDebugInstruction {
 public:
   SpirvDebugInfoNone();

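The `classof` change above moves the upper bound of a kind range, which is why the position of `IK_EmitMeshTasksEXT` in the enum matters. The following is a reduced sketch of that pattern, assuming a much smaller enum than the real one; `isInKindRange` and `isTerminatorKind` are hypothetical helpers, not DXC functions.

```cpp
#include <cassert>

// Reduced stand-in for the kind enum. The names mirror the diff, but the
// real enum has many more entries.
enum Kind {
  // Kinds covered by the range check; they must stay contiguous.
  IK_Branch,
  IK_Return,
  IK_RayTracingTerminate,
  IK_EmitMeshTasksEXT, // appended at the end of the range
  // Normal instruction kinds.
  IK_EndPrimitive,
  IK_SetMeshOutputsEXT,
};

// Range test in the style of an LLVM classof(): true when `kind` lies in
// the closed interval [first, last].
bool isInKindRange(Kind kind, Kind first, Kind last) {
  assert(first <= last && "range bounds must be ordered");
  return kind >= first && kind <= last;
}

// Range check with the upper bound moved to the newly appended kind.
bool isTerminatorKind(Kind kind) {
  return isInKindRange(kind, IK_Branch, IK_EmitMeshTasksEXT);
}
```

With the old upper bound (`IK_RayTracingTerminate`), the new kind would fall outside the range, so appending a kind and updating the bound have to happen together.
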
+ 3 - 0
tools/clang/include/clang/SPIRV/SpirvVisitor.h

@@ -143,6 +143,9 @@ public:
   DEFINE_VISIT_METHOD(SpirvReadClock)
   DEFINE_VISIT_METHOD(SpirvRayTracingTerminateOpKHR)
   DEFINE_VISIT_METHOD(SpirvIntrinsicInstruction)
+
+  DEFINE_VISIT_METHOD(SpirvEmitMeshTasksEXT)
+  DEFINE_VISIT_METHOD(SpirvSetMeshOutputsEXT)
 #undef DEFINE_VISIT_METHOD
 
   const SpirvCodeGenOptions &getCodeGenOptions() const { return spvOptions; }

+ 22 - 3
tools/clang/lib/SPIRV/CapabilityVisitor.cpp

@@ -307,7 +307,8 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
     case spv::BuiltIn::PrimitiveId: {
       // PrimitiveID can be used as PSIn or MSPOut.
       if (shaderModel == spv::ExecutionModel::Fragment ||
-          shaderModel == spv::ExecutionModel::MeshNV)
+          shaderModel == spv::ExecutionModel::MeshNV   ||
+          shaderModel == spv::ExecutionModel::MeshEXT)
         addCapability(spv::Capability::Geometry);
       break;
     }
@@ -324,7 +325,8 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
           addCapability(spv::Capability::ShaderViewportIndexLayerEXT);
         }
       } else if (shaderModel == spv::ExecutionModel::Fragment ||
-                 shaderModel == spv::ExecutionModel::MeshNV) {
+                 shaderModel == spv::ExecutionModel::MeshNV   ||
+                 shaderModel == spv::ExecutionModel::MeshEXT) {
         // SV_RenderTargetArrayIndex can be used as PSIn or MSPOut.
         addCapability(spv::Capability::Geometry);
       }
@@ -343,7 +345,8 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
         }
       } else if (shaderModel == spv::ExecutionModel::Fragment ||
                  shaderModel == spv::ExecutionModel::Geometry ||
-                 shaderModel == spv::ExecutionModel::MeshNV) {
+                 shaderModel == spv::ExecutionModel::MeshNV   ||
+                 shaderModel == spv::ExecutionModel::MeshEXT) {
         // SV_ViewportArrayIndex can be used as PSIn or GSOut or MSPOut.
         addCapability(spv::Capability::MultiViewport);
       }
@@ -558,6 +561,17 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
     }
   }
 
+  case spv::Op::OpSetMeshOutputsEXT:
+  case spv::Op::OpEmitMeshTasksEXT: {
+    if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
+      featureManager.requestTargetEnv(SPV_ENV_UNIVERSAL_1_4, "MeshShader",
+                                     {});
+      addCapability(spv::Capability::MeshShadingEXT);
+      addExtension(Extension::EXT_mesh_shader, "SPV_EXT_mesh_shader", {});
+    }
+    break;
+  }
+
   default:
     break;
   }
@@ -603,6 +617,11 @@ bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
     addCapability(spv::Capability::MeshShadingNV);
     addExtension(Extension::NV_mesh_shader, "SPV_NV_mesh_shader", {});
     break;
+  case spv::ExecutionModel::MeshEXT:
+  case spv::ExecutionModel::TaskEXT:
+    addCapability(spv::Capability::MeshShadingEXT);
+    addExtension(Extension::EXT_mesh_shader, "SPV_EXT_mesh_shader", {});
+    break;
   default:
     llvm_unreachable("found unknown shader model");
     break;

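The entry-point hunk above pairs each mesh or task execution model with one capability and one extension string. As a rough illustration of that pairing, here is a standalone sketch; `ExecutionModel`, `MeshRequirement`, and `requirementFor` are simplified, hypothetical stand-ins for the `spv::` enums and the visitor methods.

```cpp
#include <cassert>
#include <string>

// Simplified stand-in for the four mesh-related execution models.
enum class ExecutionModel { MeshNV, TaskNV, MeshEXT, TaskEXT };

// Capability and extension names that an entry point would request.
struct MeshRequirement {
  std::string capability;
  std::string extension;
};

// Mirrors the switch in the diff: both NV models share one pair of names,
// and both EXT models share the other.
MeshRequirement requirementFor(ExecutionModel model) {
  switch (model) {
  case ExecutionModel::MeshNV:
  case ExecutionModel::TaskNV:
    return {"MeshShadingNV", "SPV_NV_mesh_shader"};
  case ExecutionModel::MeshEXT:
  case ExecutionModel::TaskEXT:
    return {"MeshShadingEXT", "SPV_EXT_mesh_shader"};
  }
  assert(false && "unknown execution model");
  return {};
}
```
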
+ 75 - 22
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -774,12 +774,42 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
       if (!isVectorType(type, nullptr, &verticesPerPrim)) {
         assert(isScalarType(type));
       }
-      arraySize = arraySize * verticesPerPrim;
-      QualType arrayType = astContext.getConstantArrayType(
-          astContext.UnsignedIntTy, llvm::APInt(32, arraySize),
-          clang::ArrayType::Normal, 0);
-      stageVarInstructions[cast<DeclaratorDecl>(decl)] = getBuiltinVar(
-          spv::BuiltIn::PrimitiveIndicesNV, arrayType, decl->getLocation());
+
+      spv::BuiltIn builtinID = spv::BuiltIn::Max;
+      if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
+        // For EXT_mesh_shader, set builtin type as PrimitivePoint/Line/TriangleIndicesEXT
+        // based on the vertices per primitive
+        switch (verticesPerPrim) { 
+          case 1:
+            builtinID = spv::BuiltIn::PrimitivePointIndicesEXT;
+            break;
+          case 2:
+            builtinID = spv::BuiltIn::PrimitiveLineIndicesEXT;
+            break;
+          case 3:
+            builtinID = spv::BuiltIn::PrimitiveTriangleIndicesEXT;
+            break;
+          default:
+            break;
+        }
+        QualType arrayType = astContext.getConstantArrayType(
+            type, llvm::APInt(32, arraySize), clang::ArrayType::Normal, 0);
+
+        stageVarInstructions[cast<DeclaratorDecl>(decl)] =
+            getBuiltinVar(builtinID, arrayType, decl->getLocation());
+      } else {
+          // For NV_mesh_shader, the builtin type is PrimitiveIndicesNV
+          builtinID = spv::BuiltIn::PrimitiveIndicesNV;
+
+          arraySize = arraySize * verticesPerPrim;
+          QualType arrayType = astContext.getConstantArrayType(
+              astContext.UnsignedIntTy, llvm::APInt(32, arraySize),
+              clang::ArrayType::Normal, 0);
+
+          stageVarInstructions[cast<DeclaratorDecl>(decl)] =
+              getBuiltinVar(builtinID, arrayType, decl->getLocation());
+      }
+
       return true;
     }
   }
@@ -853,7 +883,9 @@ bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
   SemanticInfo inheritSemantic = {};
 
   if (paramDecl->hasAttr<HLSLPayloadAttr>()) {
-    spv::StorageClass sc = getStorageClassForSigPoint(sigPoint);
+    spv::StorageClass sc = (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader))
+                           ? spv::StorageClass::TaskPayloadWorkgroupEXT
+                           : getStorageClassForSigPoint(sigPoint);
     return createPayloadStageVars(sigPoint, sc, paramDecl, /*asInput=*/true,
                                   type, "in.var", loadedValue);
   } else {
@@ -3132,9 +3164,12 @@ bool DeclResultIdMapper::createPayloadStageVars(
     }
     stageVars.push_back(stageVar);
 
-    // Decorate with PerTaskNV for mesh/amplification shader payload variables.
-    spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset,
-                                 varInstr->getSourceLocation());
+    if (!featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
+      // Decorate with PerTaskNV for mesh/amplification shader payload
+      // variables.
+      spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset,
+                                   varInstr->getSourceLocation());
+    }
 
     if (asInput) {
       *value = spvBuilder.createLoad(type, varInstr, loc);
@@ -3386,9 +3421,13 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
   case spv::BuiltIn::LocalInvocationIndex:
     sc = spv::StorageClass::Input;
     break;
+  case spv::BuiltIn::TaskCountNV:
   case spv::BuiltIn::PrimitiveCountNV:
   case spv::BuiltIn::PrimitiveIndicesNV:
-  case spv::BuiltIn::TaskCountNV:
+  case spv::BuiltIn::PrimitivePointIndicesEXT:
+  case spv::BuiltIn::PrimitiveLineIndicesEXT:
+  case spv::BuiltIn::PrimitiveTriangleIndicesEXT:
+  case spv::BuiltIn::CullPrimitiveEXT:
     sc = spv::StorageClass::Output;
     break;
   default:
@@ -3795,9 +3834,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewIndex,
                                          isPrecise, srcLoc);
   }
-    // According to DXIL spec, the InnerCoverage SV can only be used as PSIn.
-    // According to Vulkan spec, the FullyCoveredEXT BuiltIn can only be used as
-    // PSIn.
+  // According to DXIL spec, the InnerCoverage SV can only be used as PSIn.
+  // According to Vulkan spec, the FullyCoveredEXT BuiltIn can only be used as
+  // PSIn.
   case hlsl::Semantic::Kind::InnerCoverage: {
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FullyCoveredEXT,
@@ -3807,14 +3846,6 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   // VSOut, or PSIn. According to Vulkan spec, the FragSizeEXT BuiltIn can only
   // be used as VSOut, GSOut, MSOut or PSIn.
   case hlsl::Semantic::Kind::ShadingRate: {
-    QualType checkType = type->getAs<ReferenceType>()
-                             ? type->getAs<ReferenceType>()->getPointeeType()
-                             : type;
-    QualType scalarTy;
-    if (!isScalarType(checkType, &scalarTy) || !scalarTy->isIntegerType()) {
-      emitError("semantic ShadingRate must be interger scalar type", srcLoc);
-    }
-
     switch (sigPointKind) {
     case hlsl::SigPoint::Kind::PSIn:
       stageVar->setIsSpirvBuiltin();
@@ -3823,6 +3854,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     case hlsl::SigPoint::Kind::VSOut:
     case hlsl::SigPoint::Kind::GSOut:
     case hlsl::SigPoint::Kind::MSOut:
+    case hlsl::SigPoint::Kind::MSPOut:
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(
           type, sc, BuiltIn::PrimitiveShadingRateKHR, isPrecise, srcLoc);
@@ -3834,6 +3866,27 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     }
     break;
   }
+  // According to DXIL spec, the CullPrimitive SV can only be used by
+  // MSPOut or PSIn.
+  // According to Vulkan spec, the CullPrimitiveEXT BuiltIn can only
+  // be used as MSOut.
+  case hlsl::Semantic::Kind::CullPrimitive: {
+    switch (sigPointKind) {
+    case hlsl::SigPoint::Kind::PSIn:
+      stageVar->setIsSpirvBuiltin();
+      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::CullPrimitiveEXT,
+                                           isPrecise, srcLoc);
+    case hlsl::SigPoint::Kind::MSPOut:
+      stageVar->setIsSpirvBuiltin();
+      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::CullPrimitiveEXT,
+                                           isPrecise, srcLoc);
+    default:
+      emitError("semantic CullPrimitive must be used only for PSIn, MSPOut",
+                srcLoc);
+      break;
+    }
+    break;
+  }
   default:
     emitError("semantic %0 unimplemented", srcLoc)
         << stageVar->getSemanticStr();

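The `createStageOutputVar` hunk above gives the primitive-index output a different shape per extension: the NV path flattens everything into one `uint` array, while the EXT path keeps one element per primitive and picks the builtin from the vertex count. A minimal sketch of that arithmetic follows; `IndicesLayout` and `primitiveIndicesLayout` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Describes the output array that would back the primitive indices.
struct IndicesLayout {
  std::string builtin;        // SPIR-V BuiltIn name
  std::uint32_t arrayLength;  // number of array elements
  std::uint32_t elementWidth; // components per element (1 = scalar uint)
};

IndicesLayout primitiveIndicesLayout(std::uint32_t primitiveCount,
                                     std::uint32_t verticesPerPrim,
                                     bool useExtMeshShader) {
  assert(verticesPerPrim >= 1 && verticesPerPrim <= 3);
  if (!useExtMeshShader) {
    // NV path: a flat array of scalar indices, so the length is scaled by
    // the number of vertices per primitive.
    return {"PrimitiveIndicesNV", primitiveCount * verticesPerPrim, 1};
  }
  // EXT path: one element per primitive; the builtin depends on whether
  // the primitive is a point, a line, or a triangle.
  static const char *const kBuiltins[] = {"PrimitivePointIndicesEXT",
                                          "PrimitiveLineIndicesEXT",
                                          "PrimitiveTriangleIndicesEXT"};
  return {kBuiltins[verticesPerPrim - 1], primitiveCount, verticesPerPrim};
}
```

For 64 triangles, this gives a 192-element scalar array on the NV path and a 64-element array of three-component vectors on the EXT path.
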
+ 24 - 0
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -1961,6 +1961,30 @@ bool EmitVisitor::visit(SpirvIntrinsicInstruction *inst) {
   return true;
 }
 
+bool EmitVisitor::visit(SpirvEmitMeshTasksEXT *inst) { 
+  initInstruction(inst);
+
+  curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getXDimension()));
+  curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getYDimension()));
+  curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getZDimension()));
+  if (inst->getPayload() != nullptr)
+  {
+      curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getPayload()));
+  }
+
+  finalizeInstruction(&mainBinary);
+  return true;
+}
+bool EmitVisitor::visit(SpirvSetMeshOutputsEXT *inst) {
+  initInstruction(inst);
+
+  curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getVertexCount()));
+  curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getPrimitiveCount()));
+
+  finalizeInstruction(&mainBinary);
+  return true;
+}
+
 // EmitTypeHandler ------
 
 void EmitTypeHandler::initTypeInstruction(spv::Op op) {

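The two `visit` overloads above push operand ids in a fixed order, with the payload id appended only when one is present. The sketch below shows how such an operand list becomes SPIR-V words, based on the general SPIR-V rule that the first word holds the word count in its high 16 bits and the opcode in its low 16 bits. The helper names are hypothetical, and the opcode is a caller-supplied parameter. What `initInstruction` and `finalizeInstruction` actually do is not visible in this diff and is not modelled here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Encodes one SPIR-V instruction that has no result id: a header word
// followed by the operand ids.
std::vector<std::uint32_t>
encodeInstruction(std::uint16_t opcode,
                  const std::vector<std::uint32_t> &operands) {
  const std::size_t wordCount = operands.size() + 1; // +1 for the header
  assert(wordCount <= 0xFFFF && "word count must fit in 16 bits");
  std::vector<std::uint32_t> words;
  words.reserve(wordCount);
  words.push_back((static_cast<std::uint32_t>(wordCount) << 16) | opcode);
  words.insert(words.end(), operands.begin(), operands.end());
  return words;
}

// Same operand order as the OpEmitMeshTasksEXT visitor in the diff: X, Y
// and Z group counts, then the payload id if there is one. An id of 0
// stands in for "no payload" here, since 0 is never a valid SPIR-V id.
std::vector<std::uint32_t> encodeEmitMeshTasks(std::uint16_t opcode,
                                               std::uint32_t xId,
                                               std::uint32_t yId,
                                               std::uint32_t zId,
                                               std::uint32_t payloadId) {
  std::vector<std::uint32_t> operands = {xId, yId, zId};
  if (payloadId != 0)
    operands.push_back(payloadId);
  return encodeInstruction(opcode, operands);
}
```

The optional payload changes the word count stored in the header, so the instruction is four words long without a payload and five with one.
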
+ 2 - 0
tools/clang/lib/SPIRV/EmitVisitor.h

@@ -300,6 +300,8 @@ public:
   bool visit(SpirvDebugTypeTemplate *) override;
   bool visit(SpirvDebugTypeTemplateParameter *) override;
   bool visit(SpirvIntrinsicInstruction *) override;
+  bool visit(SpirvEmitMeshTasksEXT *) override;
+  bool visit(SpirvSetMeshOutputsEXT *) override;
 
   using Visitor::visit;
 

+ 12 - 0
tools/clang/lib/SPIRV/FeatureManager.cpp

@@ -87,6 +87,12 @@ FeatureManager::FeatureManager(DiagnosticsEngine &de,
              {});
   }
   targetEnv = *targetEnvOpt;
+
+  // Override the default mesh extension to SPV_EXT_mesh_shader when the
+  // target environment is SPIR-V 1.4 or above
+  if (isTargetEnvSpirv1p4OrAbove()) {
+    allowExtension("SPV_EXT_mesh_shader");
+  }
 }
 
 bool FeatureManager::allowExtension(llvm::StringRef name) {
@@ -166,6 +172,7 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
             Extension::EXT_fragment_fully_covered)
       .Case("SPV_EXT_fragment_invocation_density",
             Extension::EXT_fragment_invocation_density)
+      .Case("SPV_EXT_mesh_shader", Extension::EXT_mesh_shader)
       .Case("SPV_EXT_shader_stencil_export",
             Extension::EXT_shader_stencil_export)
       .Case("SPV_EXT_shader_viewport_index_layer",
@@ -221,6 +228,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) {
     return "SPV_EXT_fragment_fully_covered";
   case Extension::EXT_fragment_invocation_density:
     return "SPV_EXT_fragment_invocation_density";
+  case Extension::EXT_mesh_shader:
+    return "SPV_EXT_mesh_shader";
   case Extension::EXT_shader_stencil_export:
     return "SPV_EXT_shader_stencil_export";
   case Extension::EXT_shader_viewport_index_layer:
@@ -332,6 +341,9 @@ bool FeatureManager::enabledByDefault(Extension ext) {
     // the user explicitly asks for it.
   case Extension::EXT_demote_to_helper_invocation:
     return false;
+  case Extension::EXT_mesh_shader:
+    // Enabling EXT_mesh_shader only when the target environment is SPIR-V 1.4 or above
+    return false;
   default:
     return true;
   }

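The FeatureManager hunks above combine two rules: `EXT_mesh_shader` is not enabled by default, but it is allowed automatically when the target environment is SPIR-V 1.4 or above. The following is a miniature model of that behaviour, not the real class. `MiniFeatureManager`, its minor-version constructor argument, and `meshExtension` are assumptions made for illustration.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical miniature of the extension bookkeeping. The target
// environment is reduced to the minor version of SPIR-V 1.x.
class MiniFeatureManager {
public:
  explicit MiniFeatureManager(unsigned spirvMinorVersion) {
    // Mirrors the constructor change: SPIR-V 1.4 or above switches the
    // mesh extension to SPV_EXT_mesh_shader.
    if (spirvMinorVersion >= 4)
      allowExtension("SPV_EXT_mesh_shader");
  }

  // Records an extension as allowed.
  void allowExtension(const std::string &name) {
    assert(!name.empty() && "extension name must not be empty");
    allowed.insert(name);
  }

  bool isExtensionEnabled(const std::string &name) const {
    return allowed.count(name) != 0;
  }

  // The extension that mesh code generation would target in this model.
  std::string meshExtension() const {
    return isExtensionEnabled("SPV_EXT_mesh_shader") ? "SPV_EXT_mesh_shader"
                                                     : "SPV_NV_mesh_shader";
  }

private:
  std::set<std::string> allowed;
};
```

In this model a SPIR-V 1.3 target falls back to `SPV_NV_mesh_shader` unless the extension is allowed explicitly, which matches the `enabledByDefault` case returning `false`.
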
+ 22 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -838,7 +838,29 @@ void SpirvBuilder::createEndPrimitive(SourceLocation loc, SourceRange range) {
   auto *inst = new (context) SpirvEndPrimitive(loc, range);
   insertPoint->addInstruction(inst);
 }
+/// \brief Creates an OpEmitMeshTasksEXT instruction.
+void SpirvBuilder::createEmitMeshTasksEXT(SpirvInstruction* xDim,
+                                          SpirvInstruction* yDim,
+                                          SpirvInstruction* zDim,
+                                          SourceLocation loc,
+                                          SpirvInstruction *payload,
+                                          SourceRange range) {
+  assert(insertPoint && "null insert point");
+  auto *inst =
+      new (context) SpirvEmitMeshTasksEXT(xDim, yDim, zDim, payload, loc, range);
+  insertPoint->addInstruction(inst);
+}
 
+/// \brief Creates an OpSetMeshOutputsEXT instruction.
+void SpirvBuilder::createSetMeshOutputsEXT(SpirvInstruction* vertCount,
+                                           SpirvInstruction* primCount,
+                                           SourceLocation loc,
+                                           SourceRange range) {
+  assert(insertPoint && "null insert point");
+  auto *inst = new (context)
+      SpirvSetMeshOutputsEXT(vertCount, primCount, loc, range);
+  insertPoint->addInstruction(inst);
+}
 SpirvArrayLength *SpirvBuilder::createArrayLength(QualType resultType,
                                                   SourceLocation loc,
                                                   SpirvInstruction *structure,
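
Note on the two builder entry points above: `createEmitMeshTasksEXT` takes the three group-count dimensions as separate operands, whereas the NV path in SpirvEmitter.cpp folds them into a single product stored to `TaskCountNV`. A minimal sketch of that difference, assuming plain integers in place of `SpirvInstruction *` values (`Dim3`, `nvTaskCount` and `extEmitOperands` are hypothetical names for illustration, not DXC code):

```cpp
#include <array>
#include <cstdint>

// Hypothetical stand-in for the three DispatchMesh() thread-group counts.
struct Dim3 {
  std::uint32_t x, y, z;
};

// NV path: one scalar, x * (y * z), is stored to the TaskCountNV builtin.
constexpr std::uint32_t nvTaskCount(Dim3 d) { return d.x * (d.y * d.z); }

// EXT path: the three counts are passed through unchanged as the
// operands of OpEmitMeshTasksEXT; no multiplication is emitted.
constexpr std::array<std::uint32_t, 3> extEmitOperands(Dim3 d) {
  return {d.x, d.y, d.z};
}
```

This mirrors the test expectations: the NV test checks for two `OpIMul` instructions before the store, while the EXT test checks for `OpEmitMeshTasksEXT` with three operands.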

+ 98 - 55
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -761,7 +761,7 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
     const FunctionInfo *entryInfo = workQueue[i];
     assert(entryInfo->isEntryFunction);
     spvBuilder.addEntryPoint(
-        getSpirvShaderStage(entryInfo->shaderModelKind),
+        getSpirvShaderStage(entryInfo->shaderModelKind, featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)),
         entryInfo->entryFunction, getEntryPointName(entryInfo),
         getInterfacesForEntryPoint(entryInfo->entryFunction));
   }
@@ -7330,6 +7330,9 @@ void SpirvEmitter::assignToMSOutIndices(
     const llvm::SmallVector<SpirvInstruction *, 4> &indices) {
   assert(spvContext.isMS() && !indices.empty());
 
+  bool extMesh =
+      featureManager.isExtensionEnabled(Extension::EXT_mesh_shader);
+
   // Extract vertex index and vecComponent (if any).
   SpirvInstruction *vertIndex = indices.front();
   SpirvInstruction *vecComponent = nullptr;
@@ -7361,45 +7364,65 @@ void SpirvEmitter::assignToMSOutIndices(
   } else {
     // for "line" or "triangle" output topology.
     assert(numVertices == 2 || numVertices == 3);
-    // set baseOffset = vertIndex * numVertices.
-    auto *baseOffset = spvBuilder.createBinaryOp(
-        spv::Op::OpIMul, astContext.UnsignedIntTy, vertIndex,
-        spvBuilder.getConstantInt(astContext.UnsignedIntTy,
-                                  llvm::APInt(32, numVertices)),
-        loc);
+
     if (vecComponent) {
       // write an individual vector component of uint2 or uint3.
       assert(numValues == 1);
-      // set baseOffset = baseOffset + vecComponent.
-      baseOffset =
-          spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
-                                    baseOffset, vecComponent, loc);
-      // create accesschain for PrimitiveIndicesNV[baseOffset].
-      auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
-                                               {baseOffset}, loc);
-      // finally create store for PrimitiveIndicesNV[baseOffset] = value.
-      spvBuilder.createStore(ptr, value, loc);
+      if (extMesh) {
+        // create accesschain for Primitive*IndicesEXT[vertIndex][vecComponent].
+        auto *ptr = spvBuilder.createAccessChain(
+            astContext.UnsignedIntTy, var, {vertIndex, vecComponent}, loc);
+        // finally create store for Primitive*IndicesEXT[vertIndex][vecComponent] = value.
+        spvBuilder.createStore(ptr, value, loc);
+      } else {
+        // set baseOffset = vertIndex * numVertices.
+        auto *baseOffset = spvBuilder.createBinaryOp(
+            spv::Op::OpIMul, astContext.UnsignedIntTy, vertIndex,
+            spvBuilder.getConstantInt(astContext.UnsignedIntTy,
+                                      llvm::APInt(32, numVertices)), loc);
+        // set baseOffset = baseOffset + vecComponent.
+        baseOffset =
+            spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
+                                      baseOffset, vecComponent, loc);
+        // create accesschain for PrimitiveIndicesNV[baseOffset].
+        auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
+                                                 {baseOffset}, loc);
+        // finally create store for PrimitiveIndicesNV[baseOffset] = value.
+        spvBuilder.createStore(ptr, value, loc);
+      }
     } else {
-      // write all vector components of uint2 or uint3.
       assert(numValues == numVertices);
-      auto *curOffset = baseOffset;
-      for (uint32_t i = 0; i < numValues; ++i) {
-        if (i != 0) {
-          // set curOffset = baseOffset + i.
-          curOffset = spvBuilder.createBinaryOp(
-              spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset,
-              spvBuilder.getConstantInt(astContext.UnsignedIntTy,
-                                        llvm::APInt(32, i)),
-              loc);
+      if (extMesh) {
+        // create accesschain for Primitive*IndicesEXT[vertIndex].
+        auto *ptr = spvBuilder.createAccessChain(varType, var, vertIndex, loc);
+        // finally create store for Primitive*IndicesEXT[vertIndex] = value.
+        spvBuilder.createStore(ptr, value, loc);
+      } else {
+        // set baseOffset = vertIndex * numVertices.
+        auto *baseOffset = spvBuilder.createBinaryOp(
+            spv::Op::OpIMul, astContext.UnsignedIntTy, vertIndex,
+            spvBuilder.getConstantInt(astContext.UnsignedIntTy,
+                                      llvm::APInt(32, numVertices)), loc);
+        // write all vector components of uint2 or uint3.
+        auto *curOffset = baseOffset;
+        for (uint32_t i = 0; i < numValues; ++i) {
+          if (i != 0) {
+            // set curOffset = baseOffset + i.
+            curOffset = spvBuilder.createBinaryOp(
+                spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset,
+                spvBuilder.getConstantInt(astContext.UnsignedIntTy,
+                                          llvm::APInt(32, i)),
+                loc);
+          }
+          // create accesschain for PrimitiveIndicesNV[curOffset].
+          auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy,
+                                                   var, {curOffset}, loc);
+          // finally create store for PrimitiveIndicesNV[curOffset] = value[i].
+          spvBuilder.createStore(ptr,
+                                 spvBuilder.createCompositeExtract(
+                                     astContext.UnsignedIntTy, value, {i}, loc),
+                                 loc);
         }
-        // create accesschain for PrimitiveIndicesNV[curOffset].
-        auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
-                                                 {curOffset}, loc);
-        // finally create store for PrimitiveIndicesNV[curOffset] = value[i].
-        spvBuilder.createStore(ptr,
-                               spvBuilder.createCompositeExtract(
-                                   astContext.UnsignedIntTy, value, {i}, loc),
-                               loc);
       }
     }
   }
@@ -11319,26 +11342,16 @@ void SpirvEmitter::processDispatchMesh(const CallExpr *callExpr) {
                                 /*isDevice*/ false,
                                 /*groupSync*/ true,
                                 /*isAllBarrier*/ false);
-
-  // 2) set TaskCountNV = threadX * threadY * threadZ.
-  auto *threadX = doExpr(args[0]);
-  auto *threadY = doExpr(args[1]);
-  auto *threadZ = doExpr(args[2]);
-  auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::TaskCountNV,
-                                         astContext.UnsignedIntTy, loc);
-  auto *taskCount = spvBuilder.createBinaryOp(
-      spv::Op::OpIMul, astContext.UnsignedIntTy, threadX,
-      spvBuilder.createBinaryOp(spv::Op::OpIMul, astContext.UnsignedIntTy,
-                                threadY, threadZ, loc, range),
-      loc, range);
-  spvBuilder.createStore(var, taskCount, loc, range);
-
-  // 3) create PerTaskNV out attribute block and store MeshPayload info.
+  
+  // 2) create PerTaskNV out attribute block and store MeshPayload info.
   const auto *sigPoint =
       hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::MSOut);
-  spv::StorageClass sc = spv::StorageClass::Output;
+  spv::StorageClass sc = featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)
+          ? spv::StorageClass::TaskPayloadWorkgroupEXT
+          : spv::StorageClass::Output;
   auto *payloadArg = doExpr(args[3]);
   bool isValid = false;
+  const VarDecl *param = nullptr;
   if (const auto *implCastExpr = dyn_cast<CastExpr>(args[3])) {
     if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
       if (const auto *paramDecl = dyn_cast<VarDecl>(arg->getDecl())) {
@@ -11346,6 +11359,7 @@ void SpirvEmitter::processDispatchMesh(const CallExpr *callExpr) {
           isValid = declIdMapper.createPayloadStageVars(
               sigPoint, sc, paramDecl, /*asInput=*/false, paramDecl->getType(),
               "out.var", &payloadArg);
+          param = paramDecl;
         }
       }
     }
@@ -11354,6 +11368,26 @@ void SpirvEmitter::processDispatchMesh(const CallExpr *callExpr) {
     emitError("expected groupshared object as argument to DispatchMesh()",
               args[3]->getExprLoc());
   }
+
+  // 3) set up emit dimension.
+  auto *threadX = doExpr(args[0]);
+  auto *threadY = doExpr(args[1]);
+  auto *threadZ = doExpr(args[2]);
+
+  if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
+    // for EXT_mesh_shader, create OpEmitMeshTasksEXT.
+    spvBuilder.createEmitMeshTasksEXT(threadX, threadY, threadZ, loc, nullptr, range);
+  } else {
+    // for NV_mesh_shader, set TaskCountNV = threadX * threadY * threadZ.
+    auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::TaskCountNV,
+                                           astContext.UnsignedIntTy, loc);
+    auto *taskCount = spvBuilder.createBinaryOp(
+        spv::Op::OpIMul, astContext.UnsignedIntTy, threadX,
+        spvBuilder.createBinaryOp(spv::Op::OpIMul, astContext.UnsignedIntTy,
+                                  threadY, threadZ, loc, range),
+        loc, range);
+    spvBuilder.createStore(var, taskCount, loc, range);
+  }
 }
 
 void SpirvEmitter::processMeshOutputCounts(const CallExpr *callExpr) {
@@ -11362,9 +11396,14 @@ void SpirvEmitter::processMeshOutputCounts(const CallExpr *callExpr) {
   const auto args = callExpr->getArgs();
   const auto loc = callExpr->getExprLoc();
   const auto range = callExpr->getSourceRange();
-  auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::PrimitiveCountNV,
+
+  if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
+    spvBuilder.createSetMeshOutputsEXT(doExpr(args[0]), doExpr(args[1]), loc, range);
+  } else {
+    auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::PrimitiveCountNV,
                                          astContext.UnsignedIntTy, loc);
-  spvBuilder.createStore(var, doExpr(args[1]), loc, range);
+    spvBuilder.createStore(var, doExpr(args[1]), loc, range);
+  }
 }
 
 SpirvConstant *SpirvEmitter::getValueZero(QualType type) {
@@ -11687,7 +11726,7 @@ hlsl::ShaderModel::Kind SpirvEmitter::getShaderModelKind(StringRef stageName) {
 }
 
 spv::ExecutionModel
-SpirvEmitter::getSpirvShaderStage(hlsl::ShaderModel::Kind smk) {
+SpirvEmitter::getSpirvShaderStage(hlsl::ShaderModel::Kind smk, bool extMeshShading) {
   switch (smk) {
   case hlsl::ShaderModel::Kind::Vertex:
     return spv::ExecutionModel::Vertex;
@@ -11714,9 +11753,13 @@ SpirvEmitter::getSpirvShaderStage(hlsl::ShaderModel::Kind smk) {
   case hlsl::ShaderModel::Kind::Callable:
     return spv::ExecutionModel::CallableNV;
   case hlsl::ShaderModel::Kind::Mesh:
-    return spv::ExecutionModel::MeshNV;
+    return extMeshShading ?
+           spv::ExecutionModel::MeshEXT: 
+           spv::ExecutionModel::MeshNV;
   case hlsl::ShaderModel::Kind::Amplification:
-    return spv::ExecutionModel::TaskNV;
+    return extMeshShading ?
+        spv::ExecutionModel::TaskEXT:
+        spv::ExecutionModel::TaskNV;
   default:
     llvm_unreachable("invalid shader model kind");
     break;
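
Note on the `assignToMSOutIndices` change above: the NV builtin `PrimitiveIndicesNV` is addressed as a flat `uint` array, so the emitter computes `index * numVertices + component`, while the EXT builtins (`Primitive*IndicesEXT`) are arrays of `uint2`/`uint3` addressed with a two-level access chain. A minimal sketch of the two addressing schemes, assuming plain integers rather than SPIR-V instructions (`nvFlatIndex`, `ExtIndex` and `extIndex` are hypothetical names, not part of DXC):

```cpp
#include <cstdint>

// NV layout: flat uint array, one slot per vertex of each primitive.
// numVertices is 2 for "line" and 3 for "triangle" output topology.
constexpr std::uint32_t nvFlatIndex(std::uint32_t primIndex,
                                    std::uint32_t numVertices,
                                    std::uint32_t component) {
  return primIndex * numVertices + component;
}

// EXT layout: array of vectors, so the access chain carries two indices
// and no offset arithmetic is needed.
struct ExtIndex {
  std::uint32_t prim;
  std::uint32_t component;
};

constexpr ExtIndex extIndex(std::uint32_t primIndex, std::uint32_t component) {
  return {primIndex, component};
}
```

For `primitiveInd[4].y` with triangle topology, the NV scheme yields flat slot 13, while the EXT scheme keeps the pair (4, 1), which matches the `%int_4 %uint_1` access chain checked in the EXT mesh test.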

+ 1 - 1
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -749,7 +749,7 @@ private:
   spv::LoopControlMask translateLoopAttribute(const Stmt *, const Attr &);
 
   static hlsl::ShaderModel::Kind getShaderModelKind(StringRef stageName);
-  static spv::ExecutionModel getSpirvShaderStage(hlsl::ShaderModel::Kind smk);
+  static spv::ExecutionModel getSpirvShaderStage(hlsl::ShaderModel::Kind smk, bool);
 
   /// \brief Adds necessary execution modes for the hull/domain shaders based on
   /// the HLSL attributes of the entry point function.

+ 16 - 0
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -110,6 +110,8 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvRayQueryOpKHR)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvReadClock)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvRayTracingTerminateOpKHR)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvIntrinsicInstruction)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvEmitMeshTasksEXT)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvSetMeshOutputsEXT)
 
 #undef DEFINE_INVOKE_VISITOR_FOR_CLASS
 
@@ -1107,5 +1109,19 @@ SpirvIntrinsicInstruction::SpirvIntrinsicInstruction(
       capabilities(capts.begin(), capts.end()),
       extensions(exts.begin(), exts.end()), instructionSet(set) {}
 
+SpirvEmitMeshTasksEXT::SpirvEmitMeshTasksEXT(
+    SpirvInstruction *xDim, SpirvInstruction *yDim, SpirvInstruction *zDim,
+    SpirvInstruction *payload, SourceLocation loc, SourceRange range)
+    : SpirvInstruction(IK_EmitMeshTasksEXT, spv::Op::OpEmitMeshTasksEXT,
+                       QualType(), loc, range),
+      xDim(xDim), yDim(yDim), zDim(zDim), payload(payload) {}
+
+SpirvSetMeshOutputsEXT::SpirvSetMeshOutputsEXT(
+    SpirvInstruction *vertCount, SpirvInstruction *primCount,
+    SourceLocation loc, SourceRange range)
+    : SpirvInstruction(IK_SetMeshOutputsEXT, spv::Op::OpSetMeshOutputsEXT,
+                       QualType(), loc, range),
+      vertCount(vertCount), primCount(primCount) {}
+
 } // namespace spirv
 } // namespace clang

+ 66 - 0
tools/clang/test/CodeGenSPIRV/meshshading.ext.amplification.hlsl

@@ -0,0 +1,66 @@
+// RUN: %dxc -T as_6_5 -fspv-target-env=vulkan1.1spirv1.4 -E main
+// CHECK:  OpCapability MeshShadingEXT
+// CHECK:  OpExtension "SPV_EXT_mesh_shader"
+// CHECK:  OpEntryPoint TaskEXT %main "main" [[drawid:%\d+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %out_var_dummy %out_var_pos
+// CHECK:  OpExecutionMode %main LocalSize 128 1 1
+
+// CHECK:  OpDecorate [[drawid]] BuiltIn DrawIndex
+// CHECK:  OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
+// CHECK:  OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId
+// CHECK:  OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
+// CHECK:  OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex
+
+
+// CHECK:  %pld = OpVariable %_ptr_Workgroup_MeshPayload Workgroup
+// CHECK:  [[drawid]] = OpVariable %_ptr_Input_int Input
+// CHECK:  %gl_LocalInvocationID = OpVariable %_ptr_Input_v3uint Input
+// CHECK:  %gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input
+// CHECK:  %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
+// CHECK:  %gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input
+// CHECK:  %out_var_dummy = OpVariable %_ptr_TaskPayloadWorkgroupEXT__arr_float_uint_10 TaskPayloadWorkgroupEXT
+// CHECK:  %out_var_pos = OpVariable %_ptr_TaskPayloadWorkgroupEXT_v4float TaskPayloadWorkgroupEXT
+struct MeshPayload {
+    float dummy[10];
+    float4 pos;
+};
+
+groupshared MeshPayload pld;
+
+#define NUM_THREADS 128
+
+[numthreads(NUM_THREADS, 1, 1)]
+void main(
+// CHECK:  %param_var_drawId = OpVariable %_ptr_Function_int Function
+// CHECK:  %param_var_gtid = OpVariable %_ptr_Function_v3uint Function
+// CHECK:  %param_var_gid = OpVariable %_ptr_Function_v2uint Function
+// CHECK:  %param_var_tid = OpVariable %_ptr_Function_uint Function
+// CHECK:  %param_var_tig = OpVariable %_ptr_Function_uint Function
+        [[vk::builtin("DrawIndex")]] in int drawId : DRAW,  // -> BuiltIn DrawIndex
+        in uint3 gtid : SV_GroupThreadID,
+        in uint2 gid : SV_GroupID,
+        in uint tid : SV_DispatchThreadID,
+        in uint tig : SV_GroupIndex)
+{
+// CHECK:  %drawId = OpFunctionParameter %_ptr_Function_int
+// CHECK:  %gtid = OpFunctionParameter %_ptr_Function_v3uint
+// CHECK:  %gid = OpFunctionParameter %_ptr_Function_v2uint
+// CHECK:  %tid = OpFunctionParameter %_ptr_Function_uint
+// CHECK:  %tig = OpFunctionParameter %_ptr_Function_uint
+// 
+// CHECK:  [[a:%\d+]] = OpAccessChain %_ptr_Workgroup_v4float %pld %int_1
+// CHECK:  OpStore [[a]] {{%\d+}}
+    pld.pos = float4(gtid.x, gid.y, tid, tig);
+
+// CHECK:  OpControlBarrier %uint_2 %uint_2 %uint_264
+// CHECK:  [[e:%\d+]] = OpLoad %MeshPayload %pld
+// CHECK:  [[f:%\d+]] = OpCompositeExtract %_arr_float_uint_10 [[e]] 0
+// CHECK:  OpStore %out_var_dummy [[f]]
+// CHECK:  [[g:%\d+]] = OpCompositeExtract %v4float [[e]] 1
+// CHECK:  OpStore %out_var_pos [[g]]
+// CHECK:  [[h:%\d+]] = OpLoad %int %drawId
+// CHECK:  [[i:%\d+]] = OpBitcast %uint [[h]]
+// CHECK:  [[j:%\d+]] = OpLoad %int %drawId
+// CHECK:  [[k:%\d+]] = OpBitcast %uint [[j]]
+// CHECK:  OpEmitMeshTasksEXT %uint_128 [[i]] [[k]]
+   DispatchMesh(NUM_THREADS, drawId, drawId, pld);
+}

+ 15 - 0
tools/clang/test/CodeGenSPIRV/meshshading.ext.ps.hlsl

@@ -0,0 +1,15 @@
+// RUN: %dxc -T ps_6_1 -E main
+
+// CHECK:      OpCapability MultiView
+// CHECK:      OpExtension "SPV_KHR_multiview"
+
+// CHECK:      OpEntryPoint Fragment
+// CHECK-SAME: [[viewindex:%\d+]]
+
+// CHECK:      OpDecorate [[viewindex]] BuiltIn ViewIndex
+
+// CHECK:      [[viewindex]] = OpVariable %_ptr_Input_uint Input
+
+float4 main(uint viewid: SV_ViewID) : SV_Target {
+    return viewid;
+}

+ 253 - 0
tools/clang/test/CodeGenSPIRV/meshshading.ext.triangle.mesh.hlsl

@@ -0,0 +1,253 @@
+// RUN: %dxc -T ms_6_5 -fspv-target-env=universal1.5 -E main
+// CHECK:  OpCapability MeshShadingEXT
+// CHECK:  OpExtension "SPV_EXT_mesh_shader"
+// CHECK:  OpEntryPoint MeshEXT %main "main" %gl_ClipDistance %gl_CullDistance %in_var_dummy %in_var_pos [[drawid:%\d+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %gl_Position %gl_PointSize %out_var_USER %out_var_USER_ARR %out_var_USER_MAT [[primindices:%\d+]] %gl_PrimitiveID %gl_Layer %gl_ViewportIndex [[cullprim:%\d+]] [[primshadingrate:%\d+]] %out_var_PRIM_USER %out_var_PRIM_USER_ARR 
+// CHECK:  OpExecutionMode %main LocalSize 128 1 1
+// CHECK:  OpExecutionMode %main OutputTrianglesNV
+// CHECK:  OpExecutionMode %main OutputVertices 64
+// CHECK:  OpExecutionMode %main OutputPrimitivesNV 81
+
+// CHECK:  OpDecorate %gl_ClipDistance BuiltIn ClipDistance
+// CHECK:  OpDecorate %gl_CullDistance BuiltIn CullDistance
+// CHECK:  OpDecorate [[drawid]] BuiltIn DrawIndex
+// CHECK:  OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
+// CHECK:  OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId
+// CHECK:  OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
+// CHECK:  OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex
+// CHECK:  OpDecorate %gl_Position BuiltIn Position
+// CHECK:  OpDecorate %gl_PointSize BuiltIn PointSize
+// CHECK:  OpDecorate [[primindices]] BuiltIn PrimitiveTriangleIndicesEXT
+// CHECK:  OpDecorate %gl_PrimitiveID BuiltIn PrimitiveId
+// CHECK:  OpDecorate %gl_PrimitiveID PerPrimitiveNV
+// CHECK:  OpDecorate %gl_Layer BuiltIn Layer
+// CHECK:  OpDecorate %gl_Layer PerPrimitiveNV
+// CHECK:  OpDecorate %gl_ViewportIndex BuiltIn ViewportIndex
+// CHECK:  OpDecorate %gl_ViewportIndex PerPrimitiveNV
+// CHECK:  OpDecorate [[cullprim]] BuiltIn CullPrimitiveEXT
+// CHECK:  OpDecorate [[cullprim]] PerPrimitiveNV
+// CHECK:  OpDecorate [[primshadingrate]] BuiltIn PrimitiveShadingRateKHR
+// CHECK:  OpDecorate [[primshadingrate]] PerPrimitiveNV
+// CHECK:  OpDecorate %out_var_PRIM_USER PerPrimitiveNV
+// CHECK:  OpDecorate %out_var_PRIM_USER_ARR PerPrimitiveNV
+// CHECK:  OpDecorate %out_var_USER Location 0
+// CHECK:  OpDecorate %out_var_USER_ARR Location 1
+// CHECK:  OpDecorate %out_var_USER_MAT Location 3
+// CHECK:  OpDecorate %out_var_PRIM_USER Location 7
+// CHECK:  OpDecorate %out_var_PRIM_USER_ARR Location 8
+
+// CHECK:  %gl_ClipDistance = OpVariable %_ptr_Output__arr__arr_float_uint_5_uint_64 Output
+// CHECK:  %gl_CullDistance = OpVariable %_ptr_Output__arr__arr_float_uint_3_uint_64 Output
+// CHECK:  %in_var_dummy = OpVariable %_ptr_TaskPayloadWorkgroupEXT__arr_float_uint_10 TaskPayloadWorkgroupEXT
+// CHECK:  %in_var_pos = OpVariable %_ptr_TaskPayloadWorkgroupEXT_v4float TaskPayloadWorkgroupEXT
+// CHECK:  %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
+// CHECK:  %gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input
+// CHECK:  %gl_Position = OpVariable %_ptr_Output__arr_v4float_uint_64 Output
+// CHECK:  %gl_PointSize = OpVariable %_ptr_Output__arr_float_uint_64 Output
+// CHECK:  %out_var_USER = OpVariable %_ptr_Output__arr_v2float_uint_64 Output
+// CHECK:  %out_var_USER_ARR = OpVariable %_ptr_Output__arr__arr_v4float_uint_2_uint_64 Output
+// CHECK:  %out_var_USER_MAT = OpVariable %_ptr_Output__arr_mat4v4float_uint_64 Output
+// CHECK:  [[primindices]] = OpVariable %_ptr_Output__arr_v3uint_uint_81 Output
+// CHECK:  %gl_PrimitiveID = OpVariable %_ptr_Output__arr_int_uint_81 Output
+// CHECK:  %gl_Layer = OpVariable %_ptr_Output__arr_int_uint_81 Output
+// CHECK:  %gl_ViewportIndex = OpVariable %_ptr_Output__arr_int_uint_81 Output
+// CHECK:  [[cullprim]] = OpVariable %_ptr_Output__arr_int_uint_81 Output
+// CHECK:  [[primshadingrate]] = OpVariable %_ptr_Output__arr_uint_uint_81 Output
+// CHECK:  %out_var_PRIM_USER = OpVariable %_ptr_Output__arr_v3float_uint_81 Output
+// CHECK:  %out_var_PRIM_USER_ARR = OpVariable %_ptr_Output__arr__arr_v4float_uint_2_uint_81 Output
+
+struct MeshPerVertex {
+    float4 position : SV_Position;                          // -> BuiltIn Position
+    [[vk::builtin("PointSize")]] float psize : PSIZE;       // -> BuiltIn PointSize
+    float3 clipdis4 : SV_ClipDistance4;                     // -> BuiltIn ClipDistance
+    float  culldis5 : SV_CullDistance5;                     // -> BuiltIn CullDistance
+    float2 clipdis3 : SV_ClipDistance3;                     // -> BuiltIn ClipDistance
+    float2 culldis6 : SV_CullDistance6;                     // -> BuiltIn CullDistance
+    float2 userVertAttr : USER;
+    float4 userVertAttrArr[2] : USER_ARR;
+    float4x4 userVertAttrMat : USER_MAT;
+};
+
+struct MeshPerPrimitive {
+    int primId : SV_PrimitiveID;                            // -> BuiltIn PrimitiveId
+    int layer  : SV_RenderTargetArrayIndex;                 // -> BuiltIn Layer
+    int vpIdx  : SV_ViewportArrayIndex;                     // -> BuiltIn ViewportIndex
+    int cullPrim : SV_CullPrimitive;                        // -> BuiltIn CullPrimitiveEXT
+    uint shadingRate : SV_ShadingRate;                      // -> BuiltIn PrimitiveShadingRateKHR
+    float3 userPrimAttr : PRIM_USER;
+    float4 userPrimAttrArr[2] : PRIM_USER_ARR;
+};
+
+struct MeshPayload {
+    float dummy[10];
+    float4 pos;
+};
+
+#define MAX_VERT 64
+#define MAX_PRIM 81
+#define NUM_THREADS 128
+
+[outputtopology("triangle")]
+[numthreads(NUM_THREADS, 1, 1)]
+void main(
+// CHECK:  %param_var_verts = OpVariable %_ptr_Function__arr_MeshPerVertex_uint_64 Function
+// CHECK:  %param_var_primitiveInd = OpVariable %_ptr_Function__arr_v3uint_uint_81 Function
+// CHECK:  %param_var_prims = OpVariable %_ptr_Function__arr_MeshPerPrimitive_uint_81 Function
+// CHECK:  %param_var_pld = OpVariable %_ptr_Function_MeshPayload Function
+// CHECK:  %param_var_drawId = OpVariable %_ptr_Function_int Function
+// CHECK:  %param_var_gtid = OpVariable %_ptr_Function_v3uint Function
+// CHECK:  %param_var_gid = OpVariable %_ptr_Function_v2uint Function
+// CHECK:  %param_var_tid = OpVariable %_ptr_Function_uint Function
+// CHECK:  %param_var_tig = OpVariable %_ptr_Function_uint Function
+        out vertices MeshPerVertex verts[MAX_VERT],
+        out indices uint3 primitiveInd[MAX_PRIM],
+        out primitives MeshPerPrimitive prims[MAX_PRIM],
+        in payload MeshPayload pld,
+        [[vk::builtin("DrawIndex")]] in int drawId : DRAW,  // -> BuiltIn DrawIndex
+        in uint3 gtid : SV_GroupThreadID,
+        in uint2 gid : SV_GroupID,
+        in uint tid : SV_DispatchThreadID,
+        in uint tig : SV_GroupIndex)
+{
+// CHECK:  OpSetMeshOutputsEXT %uint_64 %uint_81
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+
+    // Directly assign to per-vertex attribute object.
+
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_Position {{%\d+}} %uint_0
+// CHECK:  OpStore {{%\d+}} %float_11
+    verts[tid].position.x = 11.0;
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_Position {{%\d+}} %uint_1
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_Position {{%\d+}} %uint_3
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+    verts[tid].position.yw = float2(12.0,14.0);
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_Position {{%\d+}} %uint_2
+// CHECK:  OpStore {{%\d+}} %float_13
+    verts[tid].position[2] = 13.0;
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_PointSize {{%\d+}}
+// CHECK:  OpStore {{%\d+}} %float_50
+    verts[tid].psize = 50.0;
+// CHECK:  OpIAdd %uint %uint_1 %uint_2
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_ClipDistance {{%\d+}} {{%\d+}}
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+// CHECK:  OpIAdd %uint %uint_0 %uint_2
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_ClipDistance {{%\d+}} {{%\d+}}
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+// CHECK:  OpIAdd %uint %uint_2 %uint_2
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_ClipDistance {{%\d+}} {{%\d+}}
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+    verts[tid].clipdis4.yxz = float3(0.0,1.0,2.0);
+// CHECK:  OpIAdd %uint %uint_0 %uint_2
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_ClipDistance {{%\d+}} {{%\d+}}
+// CHECK:  OpStore {{%\d+}} %float_10
+    verts[tid].clipdis4[0] = 10.0;
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_CullDistance {{%\d+}} %uint_0
+// CHECK:  OpStore {{%\d+}} %float_5
+    verts[tid].culldis5 = 5.0;
+// CHECK:  OpIAdd %uint %uint_0 %uint_0
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_ClipDistance {{%\d+}} {{%\d+}}
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+// CHECK:  OpIAdd %uint %uint_0 %uint_1
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_ClipDistance {{%\d+}} {{%\d+}}
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+    verts[tid].clipdis3 = float2(11.0,12.0);
+// CHECK:  OpIAdd %uint %uint_0 %uint_1
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_CullDistance {{%\d+}} {{%\d+}}
+// CHECK:  OpStore {{%\d+}} %float_13
+    verts[tid].culldis6[0] = 13.0;
+// CHECK:  OpIAdd %uint %uint_1 %uint_1
+// CHECK:  OpAccessChain %_ptr_Output_float %gl_CullDistance {{%\d+}} {{%\d+}}
+// CHECK:  OpStore {{%\d+}} %float_14
+    verts[tid].culldis6.y = 14.0;
+// CHECK:  OpAccessChain %_ptr_Output_v2float %out_var_USER {{%\d+}}
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+    verts[tid].userVertAttr = float2(9.0, 10.0);
+// CHECK:  OpAccessChain %_ptr_Output_v4float %out_var_USER_ARR {{%\d+}} %int_0
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+    verts[tid].userVertAttrArr[0] = float4(17.0, 18.0, 19.0, 20.0);
+// CHECK:  OpAccessChain %_ptr_Output_v4float %out_var_USER_ARR {{%\d+}} %int_1
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+    verts[tid].userVertAttrArr[1] = float4(27.0, 28.0, 29.0, 30.0);
+// CHECK:  OpAccessChain %_ptr_Output_v4float %out_var_USER_MAT {{%\d+}} %uint_3
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+    verts[tid].userVertAttrMat[3] = float4(7.0, 8.0, 9.0, 10.0);
+
+    // Indirectly assign to per-vertex attribute object.
+    MeshPerVertex vert;
+    vert.position = pld.pos;
+    vert.psize = 50.0;
+    vert.clipdis4.yxz = float3(0.0,1.0,2.0);
+    vert.clipdis4[0] = 10.0;
+    vert.culldis5 = 5.0;
+    vert.clipdis3 = float2(11.0,12.0);
+    vert.culldis6[0] = 13.0;
+    vert.culldis6.y = 14.0;
+    vert.userVertAttr = float2(9.0, 10.0);
+    vert.userVertAttrArr[0] = float4(17.0, 18.0, 19.0, 20.0);
+    vert.userVertAttrArr[1] = float4(27.0, 28.0, 29.0, 30.0);
+    vert.userVertAttrMat[3] = float4(7.0, 8.0, 9.0, 10.0);
+    verts[tid+1] = vert;
+
+    // Directly assign to per-primitive attribute object.
+ 
+// CHECK:  OpAccessChain %_ptr_Output_int %gl_PrimitiveID {{%\d+}}
+// CHECK:  OpStore {{%\d+}} %int_10
+    prims[tig].primId = 10;
+// CHECK:  OpAccessChain %_ptr_Output_int %gl_Layer {{%\d+}}
+// CHECK:  OpStore {{%\d+}} %int_11
+    prims[tig].layer = 11;
+// CHECK:  OpAccessChain %_ptr_Output_int %gl_ViewportIndex {{%\d+}}
+// CHECK:  OpStore {{%\d+}} %int_12
+    prims[tig].vpIdx = 12;
+// CHECK:  OpAccessChain %_ptr_Output_int [[cullprim]] {{%\d+}}
+// CHECK:  OpStore {{%\d+}} %int_13
+    prims[tig].cullPrim = 13;
+// CHECK:  OpAccessChain %_ptr_Output_uint [[primshadingrate]] {{%\d+}}
+// CHECK:  OpStore {{%\d+}} %uint_14
+    prims[tig].shadingRate = 14;
+
+// CHECK:  OpAccessChain %_ptr_Output_v4float %out_var_PRIM_USER_ARR {{%\d+}} %int_0
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+    prims[tig].userPrimAttrArr[0] = float4(4.0,5.0,6.0,7.0);
+// CHECK:  OpAccessChain %_ptr_Output_v4float %out_var_PRIM_USER_ARR {{%\d+}} %int_1
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+    prims[tig].userPrimAttrArr[1] = float4(8.0,9.0,10.0,11.0);
+// CHECK:  OpAccessChain %_ptr_Output_v3float %out_var_PRIM_USER {{%\d+}}
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+    prims[tig].userPrimAttr = float3(14.0,15.0,16.0);
+
+    // Indirectly assign to per-primitive attribute object.
+    MeshPerPrimitive prim;
+    prim.primId = 10;
+    prim.layer = 11;
+    prim.vpIdx = 12;
+    prim.cullPrim = 1;
+    prim.shadingRate = 0xa;
+    prim.userPrimAttrArr[0] = float4(4.0,5.0,6.0,7.0);
+    prim.userPrimAttrArr[1] = float4(8.0,9.0,10.0,11.0);
+    prim.userPrimAttr = float3(14.0,15.0,16.0);
+    prims[tig+1] = prim;
+ 
+    // Assign primitive indices.
+
+// CHECK:  OpAccessChain %_ptr_Output_uint [[primindices]] %int_4 %uint_0
+// CHECK:  OpStore {{%\d+}} %uint_1
+    primitiveInd[4].x = 1;
+// CHECK:  OpCompositeExtract %uint {{%\d+}} 0
+// CHECK:  OpAccessChain %_ptr_Output_uint [[primindices]] %int_4 %uint_1
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+// CHECK:  OpCompositeExtract %uint {{%\d+}} 1
+// CHECK:  OpAccessChain %_ptr_Output_uint [[primindices]] %int_4 %uint_2
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+    primitiveInd[4].yz = uint2(2,3);
+// CHECK:  OpAccessChain %_ptr_Output_uint [[primindices]] %int_2 %uint_1
+// CHECK:  OpStore {{%\d+}} %uint_2
+    primitiveInd[2].y = 2;
+// CHECK:  OpAccessChain %_ptr_Output_uint [[primindices]] %int_2 %uint_2
+// CHECK:  OpStore {{%\d+}} %uint_1
+    primitiveInd[2][2] = 1;
+// CHECK:  OpLoad %uint %tid
+// CHECK:  OpAccessChain %_ptr_Output_v3uint [[primindices]] {{%\d+}}
+// CHECK:  OpStore {{%\d+}} {{%\d+}}
+    primitiveInd[tid] = uint3(11,12,13);
+}

+ 26 - 15
tools/clang/test/CodeGenSPIRV/meshshading.nv.amplification.hlsl

@@ -1,7 +1,7 @@
 // RUN: %dxc -T as_6_5 -E main
 // CHECK:  OpCapability MeshShadingNV
 // CHECK:  OpExtension "SPV_NV_mesh_shader"
-// CHECK:  OpEntryPoint TaskNV %main "main" [[drawid:%\d+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex [[taskcount:%\d+]] %out_var_dummy %out_var_pos
+// CHECK:  OpEntryPoint TaskNV %main "main" [[drawid:%\d+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %out_var_dummy %out_var_pos [[taskcount:%\d+]]
 // CHECK:  OpExecutionMode %main LocalSize 128 1 1
 
 // CHECK:  OpDecorate [[drawid]] BuiltIn DrawIndex
@@ -9,7 +9,6 @@
 // CHECK:  OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId
 // CHECK:  OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
 // CHECK:  OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex
-// CHECK:  OpDecorate [[taskcount]] BuiltIn TaskCountNV
 
 struct MeshPayload {
 // CHECK:  OpDecorate %out_var_dummy PerTaskNV
@@ -20,14 +19,17 @@ struct MeshPayload {
     float4 pos;
 };
 
+// CHECK:  OpDecorate [[taskcount]] BuiltIn TaskCountNV
+
 // CHECK:  %pld = OpVariable %_ptr_Workgroup_MeshPayload Workgroup
+// CHECK:  [[drawid]] = OpVariable %_ptr_Input_int Input
 groupshared MeshPayload pld;
 
 // CHECK:  %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
 // CHECK:  %gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input
-// CHECK:  [[taskcount]] = OpVariable %_ptr_Output_uint Output
 // CHECK:  %out_var_dummy = OpVariable %_ptr_Output__arr_float_uint_10 Output
 // CHECK:  %out_var_pos = OpVariable %_ptr_Output_v4float Output
+// CHECK:  [[taskcount]] = OpVariable %_ptr_Output_uint Output
 
 #define NUM_THREADS 128
 
@@ -44,19 +46,28 @@ void main(
         in uint tid : SV_DispatchThreadID,
         in uint tig : SV_GroupIndex)
 {
+// CHECK:  %drawId = OpFunctionParameter %_ptr_Function_int
+// CHECK:  %gtid = OpFunctionParameter %_ptr_Function_v3uint
+// CHECK:  %gid = OpFunctionParameter %_ptr_Function_v2uint
+// CHECK:  %tid = OpFunctionParameter %_ptr_Function_uint
+// CHECK:  %tig = OpFunctionParameter %_ptr_Function_uint
 
-// CHECK:  [[a:%\d+]] = OpAccessChain %_ptr_Workgroup_v4float %pld %int_1
-// CHECK:  OpStore [[a]] {{%\d+}}
-    pld.pos = float4(3.0,4.0,5.0,6.0);
+// CHECK:  [[b:%\d+]] = OpAccessChain %_ptr_Workgroup_v4float %pld %int_1
+// CHECK:  OpStore [[b]] {{%\d+}}
+    pld.pos = float4(gtid.x, gid.y, tid, tig);
 
 // CHECK:  OpControlBarrier %uint_2 %uint_2 %uint_264
-// CHECK:  [[c:%\d+]] = OpIMul %uint %uint_1 %uint_1
-// CHECK:  [[d:%\d+]] = OpIMul %uint %uint_128 [[c]]
-// CHECK:  OpStore [[taskcount]] [[d]]
-// CHECK:  [[e:%\d+]] = OpLoad %MeshPayload %pld
-// CHECK:  [[f:%\d+]] = OpCompositeExtract %_arr_float_uint_10 [[e]] 0
-// CHECK:  OpStore %out_var_dummy [[f]]
-// CHECK:  [[g:%\d+]] = OpCompositeExtract %v4float [[e]] 1
-// CHECK:  OpStore %out_var_pos [[g]]
-    DispatchMesh(NUM_THREADS, 1, 1, pld);
+// CHECK:  [[c:%\d+]] = OpLoad %MeshPayload %pld
+// CHECK:  [[d:%\d+]] = OpCompositeExtract %_arr_float_uint_10 [[c]] 0
+// CHECK:  OpStore %out_var_dummy [[d]]
+// CHECK:  [[e:%\d+]] = OpCompositeExtract %v4float [[c]] 1
+// CHECK:  OpStore %out_var_pos [[e]]
+// CHECK:  [[f:%\d+]] = OpLoad %int %drawId
+// CHECK:  [[g:%\d+]] = OpBitcast %uint [[f]]
+// CHECK:  [[h:%\d+]] = OpLoad %int %drawId
+// CHECK:  [[i:%\d+]] = OpBitcast %uint [[h]]
+// CHECK:  [[j:%\d+]] = OpIMul %uint [[g]] [[i]]
+// CHECK:  [[k:%\d+]] = OpIMul %uint %uint_128 [[j]]
+// CHECK:  OpStore [[taskcount]] [[k]]
+   DispatchMesh(NUM_THREADS, drawId, drawId, pld);
 }

+ 0 - 0
tools/clang/test/CodeGenSPIRV/meshshading.nv.fncall.amplification.vulkan1.2.hlsl → tools/clang/test/CodeGenSPIRV/meshshading.nv.error.fncall.amplification.vulkan1.2.hlsl


+ 9 - 4
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -2735,6 +2735,15 @@ TEST_F(FileTest, VulkanEarlyAndLateTestsStencilRefErrorBack) {
               Expect::Failure);
 }
 
+// === MeshShading EXT examples ===
+TEST_F(FileTest, MeshShadingEXTMeshTriangle) {
+  runFileTest("meshshading.ext.triangle.mesh.hlsl");
+}
+
+TEST_F(FileTest, MeshShadingEXTAmplification) {
+  runFileTest("meshshading.ext.amplification.hlsl");
+}
+
 // === MeshShading NV examples ===
 TEST_F(FileTest, MeshShadingNVMeshTriangle) {
   // TODO: Re-enable spirv-val once issue#3006 is fixed.
@@ -2850,10 +2859,6 @@ TEST_F(FileTest, Vk1p2RemoveBufferBlockPtrToPtr2) {
 // -fspv-target-env=vulkan1.2 option to make sure that enabling
 // Vulkan1.2 also enables Vulkan1.1.
 TEST_F(FileTest, CompatibilityWithVk1p1) {
-  // TODO: Re-enable spirv-val once issue#3006 is fixed.
-  runFileTest("meshshading.nv.fncall.amplification.vulkan1.2.hlsl",
-              Expect::Success,
-              /* runValidation */ false);
   runFileTest("sm6.quad-read-across-diagonal.vulkan1.2.hlsl");
   runFileTest("sm6.quad-read-across-x.vulkan1.2.hlsl");
   runFileTest("sm6.quad-read-across-y.vulkan1.2.hlsl");