Răsfoiți Sursa

[spirv] Fix Wave*CountBits() and improve compilation time (#1574)

Fixed the compilation time regression introduced by loop
unrolling by updating SPIRV-Tools.
Lei Zhang 6 ani în urmă
părinte
comite
8e6b468c4d

+ 1 - 1
external/SPIRV-Headers

@@ -1 +1 @@
-Subproject commit dcf23bdabacc3c54b83b1f9367e7a8adb27f8d87
+Subproject commit d5b2e1255f706ce1f88812217e9a554f299848af

+ 1 - 1
external/SPIRV-Tools

@@ -1 +1 @@
-Subproject commit 9fbcce4ca17de7b2d8f6b322bcd1d43a7d6adc29
+Subproject commit 4b4bd4c53aaa020f7e349aede394d42476b7e3aa

+ 1 - 1
external/googletest

@@ -1 +1 @@
-Subproject commit d25268a55f6f6f38c65a7d1b7b119e33a46d1688
+Subproject commit 440527a61e1c91188195f7de212c63c77e8f0a45

+ 30 - 8
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -6635,9 +6635,7 @@ SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
     retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAllEqual);
     break;
   case hlsl::IntrinsicOp::IOP_WaveActiveCountBits:
-    retVal = processWaveReductionOrPrefix(
-        callExpr, spv::Op::OpGroupNonUniformBallotBitCount,
-        spv::GroupOperation::Reduce);
+    retVal = processWaveCountBits(callExpr, spv::GroupOperation::Reduce);
     break;
   case hlsl::IntrinsicOp::IOP_WaveActiveUSum:
   case hlsl::IntrinsicOp::IOP_WaveActiveSum:
@@ -6665,9 +6663,7 @@ SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
         spv::GroupOperation::ExclusiveScan);
   } break;
   case hlsl::IntrinsicOp::IOP_WavePrefixCountBits:
-    retVal = processWaveReductionOrPrefix(
-        callExpr, spv::Op::OpGroupNonUniformBallotBitCount,
-        spv::GroupOperation::ExclusiveScan);
+    retVal = processWaveCountBits(callExpr, spv::GroupOperation::ExclusiveScan);
     break;
   case hlsl::IntrinsicOp::IOP_WaveReadLaneAt:
   case hlsl::IntrinsicOp::IOP_WaveReadLaneFirst:
@@ -7194,11 +7190,38 @@ spv::Op SPIRVEmitter::translateWaveOp(hlsl::IntrinsicOp op, QualType type,
   return spv::Op::OpNop;
 }
 
+uint32_t SPIRVEmitter::processWaveCountBits(const CallExpr *callExpr,
+                                            spv::GroupOperation groupOp) {
+  // Signatures:
+  // uint WaveActiveCountBits(bool bBit)
+  // uint WavePrefixCountBits(Bool bBit)
+  assert(callExpr->getNumArgs() == 1);
+
+  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
+                                  callExpr->getExprLoc());
+  theBuilder.requireCapability(getCapabilityForGroupNonUniform(
+      spv::Op::OpGroupNonUniformBallotBitCount));
+
+  const uint32_t predicate = doExpr(callExpr->getArg(0));
+  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+
+  const uint32_t u32Type = theBuilder.getUint32Type();
+  const uint32_t v4u32Type = theBuilder.getVecType(u32Type, 4);
+  const uint32_t retType =
+      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+
+  const uint32_t ballot = theBuilder.createGroupNonUniformUnaryOp(
+      spv::Op::OpGroupNonUniformBallot, v4u32Type, subgroupScope, predicate);
+
+  return theBuilder.createGroupNonUniformUnaryOp(
+      spv::Op::OpGroupNonUniformBallotBitCount, retType, subgroupScope, ballot,
+      llvm::Optional<spv::GroupOperation>(groupOp));
+}
+
 uint32_t SPIRVEmitter::processWaveReductionOrPrefix(
     const CallExpr *callExpr, spv::Op opcode, spv::GroupOperation groupOp) {
   // Signatures:
   // bool WaveActiveAllEqual( <type> expr )
-  // uint WaveActiveCountBits( bool bBit )
   // <type> WaveActiveSum( <type> expr )
   // <type> WaveActiveProduct( <type> expr )
   // <int_type> WaveActiveBitAnd( <int_type> expr )
@@ -7207,7 +7230,6 @@ uint32_t SPIRVEmitter::processWaveReductionOrPrefix(
   // <type> WaveActiveMin( <type> expr)
   // <type> WaveActiveMax( <type> expr)
   //
-  // uint WavePrefixCountBits(Bool bBit)
   // <type> WavePrefixProduct(<type> value)
   // <type> WavePrefixSum(<type> value)
   assert(callExpr->getNumArgs() == 1);

+ 3 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -467,6 +467,9 @@ private:
   /// Processes SM6.0 wave vote intrinsic calls.
   uint32_t processWaveVote(const CallExpr *, spv::Op opcode);
 
+  /// Processes SM6.0 wave active/prefix count bits.
+  uint32_t processWaveCountBits(const CallExpr *, spv::GroupOperation groupOp);
+
   /// Processes SM6.0 wave reduction or scan/prefix intrinsic calls.
   uint32_t processWaveReductionOrPrefix(const CallExpr *, spv::Op op,
                                         spv::GroupOperation groupOp);

+ 5 - 2
tools/clang/test/CodeGenSPIRV/sm6.wave-active-count-bits.hlsl

@@ -7,6 +7,7 @@ struct S {
 };
 
 RWStructuredBuffer<S> values;
+RWStructuredBuffer<S> results;
 
 // CHECK: OpCapability GroupNonUniformBallot
 
@@ -14,7 +15,9 @@ RWStructuredBuffer<S> values;
 void main(uint3 id: SV_DispatchThreadID) {
     uint x = id.x;
 
-// CHECK:  {{%\d+}} = OpGroupNonUniformBallotBitCount %uint %int_3 Reduce {{%\d+}}
-    values[x].val = WaveActiveCountBits(values[x].val == 0);
+// CHECK:         [[cmp:%\d+]] = OpIEqual %bool {{%\d+}} %uint_0
+// CHECK-NEXT: [[ballot:%\d+]] = OpGroupNonUniformBallot %v4uint %int_3 [[cmp]]
+// CHECK:             {{%\d+}} = OpGroupNonUniformBallotBitCount %uint %int_3 Reduce [[ballot]]
+    results[x].val = WaveActiveCountBits(values[x].val == 0);
 }
 

+ 3 - 1
tools/clang/test/CodeGenSPIRV/sm6.wave-prefix-count-bits.hlsl

@@ -14,6 +14,8 @@ RWStructuredBuffer<S> values;
 void main(uint3 id: SV_DispatchThreadID) {
     uint x = id.x;
 
-// CHECK:  {{%\d+}} = OpGroupNonUniformBallotBitCount %uint %int_3 ExclusiveScan {{%\d+}}
+// CHECK:         [[cmp:%\d+]] = OpIEqual %bool {{%\d+}} %uint_0
+// CHECK-NEXT: [[ballot:%\d+]] = OpGroupNonUniformBallot %v4uint %int_3 [[cmp]]
+// CHECK:             {{%\d+}} = OpGroupNonUniformBallotBitCount %uint %int_3 ExclusiveScan [[ballot]]
     values[x].val = WavePrefixCountBits(values[x].val == 0);
 }