Ver Fonte

[spirv] Add support for WaveReadLineFirst() (#1106)

Lei Zhang há 7 anos atrás
pai
commit
c1ea245a16

+ 6 - 0
docs/SPIR-V.rst

@@ -2490,6 +2490,12 @@ according to the following table:
 ``WaveGetLaneIndex()`` ``SubgroupLocalInvocationId`` ``SPV_KHR_shader_ballot``
 ``WaveGetLaneIndex()`` ``SubgroupLocalInvocationId`` ``SPV_KHR_shader_ballot``
 ====================== ============================= =========================
 ====================== ============================= =========================
 
 
+======================= ================================ =========================
+      Intrinsic               SPIR-V Instruction                Extension
+======================= ================================ =========================
+``WaveReadLaneFirst()`` ``OpSubgroupFirstInvocationKHR`` ``SPV_KHR_shader_ballot``
+======================= ================================ =========================
+
 Vulkan Command-line Options
 Vulkan Command-line Options
 ===========================
 ===========================
 
 

+ 3 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -303,6 +303,9 @@ public:
   /// \brief Creates an OpEndPrimitive instruction.
   /// \brief Creates an OpEndPrimitive instruction.
   void createEndPrimitive();
   void createEndPrimitive();
 
 
+  /// \brief Creates an OpSubgroupFirstInvocationKHR instruciton.
+  uint32_t createSubgroupFirstInvocation(uint32_t resultType, uint32_t value);
+
   // === SPIR-V Module Structure ===
   // === SPIR-V Module Structure ===
 
 
   inline void requireCapability(spv::Capability);
   inline void requireCapability(spv::Capability);

+ 12 - 0
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -705,6 +705,18 @@ void ModuleBuilder::createEndPrimitive() {
   insertPoint->appendInstruction(std::move(constructSite));
   insertPoint->appendInstruction(std::move(constructSite));
 }
 }
 
 
+uint32_t ModuleBuilder::createSubgroupFirstInvocation(uint32_t resultType,
+                                                      uint32_t value) {
+  assert(insertPoint && "null insert point");
+  addExtension("SPV_KHR_shader_ballot");
+  requireCapability(spv::Capability::SubgroupBallotKHR);
+
+  uint32_t resultId = theContext.takeNextId();
+  instBuilder.opSubgroupFirstInvocationKHR(resultType, resultId, value).x();
+  insertPoint->appendInstruction(std::move(constructSite));
+  return resultId;
+}
+
 void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
 void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
                                      spv::ExecutionMode em,
                                      spv::ExecutionMode em,
                                      llvm::ArrayRef<uint32_t> params) {
                                      llvm::ArrayRef<uint32_t> params) {

+ 11 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -6018,6 +6018,17 @@ SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
         declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupLocalInvocationId);
         declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupLocalInvocationId);
     retVal = theBuilder.createLoad(retType, varId);
     retVal = theBuilder.createLoad(retType, varId);
   } break;
   } break;
+  case hlsl::IntrinsicOp::IOP_WaveReadLaneFirst: {
+    const auto retType = callExpr->getCallReturnType(astContext);
+    if (!retType->isScalarType()) {
+      emitError("vector overloads of WaveReadLaneFirst unimplemented",
+                callExpr->getExprLoc());
+      return 0;
+    }
+    const uint32_t retTypeId = typeTranslator.translateType(retType);
+    retVal = theBuilder.createSubgroupFirstInvocation(
+        retTypeId, doExpr(callExpr->getArg(0)));
+  } break;
   case hlsl::IntrinsicOp::IOP_abort:
   case hlsl::IntrinsicOp::IOP_abort:
   case hlsl::IntrinsicOp::IOP_GetRenderTargetSampleCount:
   case hlsl::IntrinsicOp::IOP_GetRenderTargetSampleCount:
   case hlsl::IntrinsicOp::IOP_GetRenderTargetSamplePosition: {
   case hlsl::IntrinsicOp::IOP_GetRenderTargetSamplePosition: {

+ 30 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-read-lane-first.hlsl

@@ -0,0 +1,30 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: OpCapability SubgroupBallotKHR
+// CHECK: OpExtension "SPV_KHR_shader_ballot"
+
+struct S {
+    uint4 val1;
+     int2 val2;
+    float val3;
+};
+
+RWStructuredBuffer<S> values;
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+
+    uint4 val1 = values[x].val1;
+     int2 val2 = values[x].val2;
+    float val3 = values[x].val3;
+
+// OpSubgroupFirstInvocationKHR requires that:
+//   Result Type must be a 32-bit integer type or a 32-bit float type scalar.
+
+    // values[x].val1 = WaveReadLaneFirst(val1);
+    // values[x].val2 = WaveReadLaneFirst(val2);
+// CHECK:      [[val3:%\d+]] = OpLoad %float %val3
+// CHECK-NEXT:      {{%\d+}} = OpSubgroupFirstInvocationKHR %float [[val3]]
+    values[x].val3 = WaveReadLaneFirst(val3);
+}

+ 5 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -986,6 +986,11 @@ TEST_F(FileTest, SM6WaveBuiltInNoDuplicate) {
   runFileTest("sm6.wave.builtin.no-dup.hlsl");
   runFileTest("sm6.wave.builtin.no-dup.hlsl");
 }
 }
 
 
+// Shader model 6.0 wave broadcast
+TEST_F(FileTest, SM6WaveReadLaneFirst) {
+  runFileTest("sm6.wave-read-lane-first.hlsl");
+}
+
 // SPIR-V specific
 // SPIR-V specific
 TEST_F(FileTest, SpirvStorageClass) { runFileTest("spirv.storage-class.hlsl"); }
 TEST_F(FileTest, SpirvStorageClass) { runFileTest("spirv.storage-class.hlsl"); }