Răsfoiți Sursa

[spirv] Support WaveGetLaneCount() and WaveGetLaneIndex() (#1077)

They are translated into SPIR-V builtin variables. The translation
requires the SPV_KHR_shader_ballot extension.
Lei Zhang 7 ani în urmă
părinte
comite
4221a698e1

+ 13 - 1
docs/SPIR-V.rst

@@ -2257,7 +2257,6 @@ element is the height, and the third is the elements.
 The ``OpImageQuerySize`` instruction is used to get a uint3. The first element is the width, the second
 element is the height, and the third element is the depth.
 
-
 HLSL Shader Stages
 ==================
 
@@ -2424,6 +2423,19 @@ behind ``T`` will be flushed before SPIR-V ``OpEmitVertex`` instruction is
 generated. ``.RestartStrip()`` method calls will be translated into the SPIR-V
 ``OpEndPrimitive`` instruction.
 
+Shader Model 6.0 Wave Intrinsics
+================================
+
+Shader Model 6.0 introduces a set of wave operations, which are translated
+according to the following table:
+
+====================== ============================= =========================
+      Intrinsic               SPIR-V BuiltIn                Extension
+====================== ============================= =========================
+``WaveGetLaneCount()`` ``SubgroupSize``              ``SPV_KHR_shader_ballot``
+``WaveGetLaneIndex()`` ``SubgroupLocalInvocationId`` ``SPV_KHR_shader_ballot``
+====================== ============================= =========================
+
 Vulkan Command-line Options
 ===========================
 

+ 1 - 1
external/SPIRV-Headers

@@ -1 +1 @@
-Subproject commit e0282aa7d54631502b4af567a85d3b6565fd5464
+Subproject commit 2bf91d32b2ce17df9ca6c1e62cf478b24e7d2644

+ 52 - 0
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -1673,6 +1673,58 @@ void DeclResultIdMapper::decoratePSInterpolationMode(const NamedDecl *decl,
   }
 }
 
+uint32_t DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn) {
+  // Guarantee uniqueness
+  switch (builtIn) {
+  case spv::BuiltIn::SubgroupSize:
+    if (laneCountBuiltinId)
+      return laneCountBuiltinId;
+    break;
+  case spv::BuiltIn::SubgroupLocalInvocationId:
+    if (laneIndexBuiltinId)
+      return laneIndexBuiltinId;
+    break;
+  default:
+    // Only allow the two cases we know about
+    assert(false && "unsupported builtin case");
+    return 0;
+  }
+
+  // Both of them require the SPV_KHR_shader_ballot extension.
+  theBuilder.addExtension("SPV_KHR_shader_ballot");
+  theBuilder.requireCapability(spv::Capability::SubgroupBallotKHR);
+
+  uint32_t type = theBuilder.getUint32Type();
+
+  // Create a dummy StageVar for this builtin variable
+  const uint32_t varId =
+      theBuilder.addStageBuiltinVar(type, spv::StorageClass::Input, builtIn);
+
+  const hlsl::SigPoint *sigPoint =
+      hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
+          hlsl::DxilParamInputQual::In, shaderModel.GetKind(),
+          /*isPatchConstant=*/false));
+
+  StageVar stageVar(sigPoint, /*semaStr=*/"", hlsl::Semantic::GetInvalid(),
+                    /*semaName=*/"", /*semaIndex=*/0, /*builtinAttr=*/nullptr,
+                    type);
+
+  stageVar.setIsSpirvBuiltin();
+  stageVar.setSpirvId(varId);
+  stageVars.push_back(stageVar);
+
+  switch (builtIn) {
+  case spv::BuiltIn::SubgroupSize:
+    laneCountBuiltinId = varId;
+    break;
+  case spv::BuiltIn::SubgroupLocalInvocationId:
+    laneIndexBuiltinId = varId;
+    break;
+  }
+
+  return varId;
+}
+
 uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
                                                  const NamedDecl *decl,
                                                  const llvm::StringRef name,

+ 13 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -255,6 +255,9 @@ public:
                             ModuleBuilder &builder,
                             const EmitSPIRVOptions &spirvOptions);
 
+  /// \brief Returns the <result-id> for a SPIR-V builtin variable.
+  uint32_t getBuiltinVar(spv::BuiltIn builtIn);
+
   /// \brief Creates the stage output variables by parsing the semantics
   /// attached to the given function's parameter or return value and returns
   /// true on success. SPIR-V instructions will also be generated to update the
@@ -648,6 +651,15 @@ private:
   /// to the <type-id>
   llvm::DenseMap<const DeclContext *, uint32_t> ctBufferPCTypeIds;
 
+  /// <result-id> for the SPIR-V builtin variables accessed by
+  /// WaveGetLaneCount() and WaveGetLaneIndex().
+  ///
+  /// These are the only two cases where SPIR-V builtin variables are accessed
+  /// using HLSL intrinsic function calls. All other builtin variables are
+  /// accessed using stage IO variables.
+  uint32_t laneCountBuiltinId;
+  uint32_t laneIndexBuiltinId;
+
   /// Whether the translated SPIR-V binary needs legalization.
   ///
   /// The following cases will require legalization:
@@ -718,7 +730,7 @@ DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
     : shaderModel(model), theBuilder(builder), spirvOptions(options),
       astContext(context), diags(context.getDiagnostics()),
       typeTranslator(context, builder, diags, options), entryFunctionId(0),
-      needsLegalization(false),
+      laneCountBuiltinId(0), laneIndexBuiltinId(0), needsLegalization(false),
       glPerVertex(model, context, builder, typeTranslator, options.invertY) {}
 
 bool DeclResultIdMapper::decorateStageIOLocations() {

+ 14 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -5895,6 +5895,20 @@ SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
   case hlsl::IntrinsicOp::IOP_f32tof16:
     retVal = processIntrinsicF32ToF16(callExpr);
     break;
+  case hlsl::IntrinsicOp::IOP_WaveGetLaneCount: {
+    const uint32_t retType =
+        typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+    const uint32_t varId =
+        declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupSize);
+    retVal = theBuilder.createLoad(retType, varId);
+  } break;
+  case hlsl::IntrinsicOp::IOP_WaveGetLaneIndex: {
+    const uint32_t retType =
+        typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+    const uint32_t varId =
+        declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupLocalInvocationId);
+    retVal = theBuilder.createLoad(retType, varId);
+  } break;
   case hlsl::IntrinsicOp::IOP_abort:
   case hlsl::IntrinsicOp::IOP_GetRenderTargetSampleCount:
   case hlsl::IntrinsicOp::IOP_GetRenderTargetSamplePosition: {

+ 19 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-get-lane-count.hlsl

@@ -0,0 +1,19 @@
+// Run: %dxc -T cs_6_0 -E main
+
+RWStructuredBuffer<uint> values;
+
+// CHECK: OpCapability SubgroupBallotKHR
+// CHECK: OpExtension "SPV_KHR_shader_ballot"
+
+// CHECK: OpEntryPoint GLCompute
+// CHECK-SAME: %SubgroupSize
+
+// CHECK: OpDecorate %SubgroupSize BuiltIn SubgroupSize
+
+// CHECK: %SubgroupSize = OpVariable %_ptr_Input_uint Input
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+// CHECK: OpLoad %uint %SubgroupSize
+    values[id.x] = WaveGetLaneCount();
+}

+ 19 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-get-lane-index.hlsl

@@ -0,0 +1,19 @@
+// Run: %dxc -T cs_6_0 -E main
+
+RWStructuredBuffer<uint> values;
+
+// CHECK: OpCapability SubgroupBallotKHR
+// CHECK: OpExtension "SPV_KHR_shader_ballot"
+
+// CHECK: OpEntryPoint GLCompute
+// CHECK-SAME: %SubgroupLocalInvocationId
+
+// CHECK: OpDecorate %SubgroupLocalInvocationId BuiltIn SubgroupLocalInvocationId
+
+// CHECK: %SubgroupLocalInvocationId = OpVariable %_ptr_Input_uint Input
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+// CHECK: OpLoad %uint %SubgroupLocalInvocationId
+    values[id.x] = WaveGetLaneIndex();
+}

+ 27 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave.builtin.no-dup.hlsl

@@ -0,0 +1,27 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// Some wave ops translate into SPIR-V builtin variables.
+// Test that we are not generating duplicated builtins for multiple calls
+// of the same wave ops.
+RWStructuredBuffer<uint> values;
+
+// CHECK: OpEntryPoint GLCompute
+// CHECK-SAME: %SubgroupSize %SubgroupLocalInvocationId
+
+// CHECK: OpDecorate %SubgroupSize BuiltIn SubgroupSize
+// CHECK-NOT: OpDecorate {{%\w+}} BuiltIn SubgroupSize
+
+// CHECK: OpDecorate %SubgroupLocalInvocationId BuiltIn SubgroupLocalInvocationId
+// CHECK-NOT: OpDecorate {{%\w+}} BuiltIn SubgroupLocalInvocationId
+
+// CHECK: %SubgroupSize = OpVariable %_ptr_Input_uint Input
+// CHECK-NEXT: %SubgroupLocalInvocationId = OpVariable %_ptr_Input_uint Input
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+// CHECK: OpLoad %uint %SubgroupSize
+// CHECK: OpLoad %uint %SubgroupSize
+// CHECK: OpLoad %uint %SubgroupLocalInvocationId
+// CHECK: OpLoad %uint %SubgroupLocalInvocationId
+    values[id.x] = WaveGetLaneCount() + WaveGetLaneCount() + WaveGetLaneIndex() + WaveGetLaneIndex();
+}

+ 11 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -969,6 +969,17 @@ TEST_F(FileTest, PrimitiveErrorGS) {
   runFileTest("primitive.error.gs.hlsl", Expect::Failure);
 }
 
+// Shader model 6.0 wave query
+TEST_F(FileTest, SM6WaveGetLaneCount) {
+  runFileTest("sm6.wave-get-lane-count.hlsl");
+}
+TEST_F(FileTest, SM6WaveGetLaneIndex) {
+  runFileTest("sm6.wave-get-lane-index.hlsl");
+}
+TEST_F(FileTest, SM6WaveBuiltInNoDuplicate) {
+  runFileTest("sm6.wave.builtin.no-dup.hlsl");
+}
+
 // SPIR-V specific
 TEST_F(FileTest, SpirvStorageClass) { runFileTest("spirv.storage-class.hlsl"); }