Ver código fonte

[spirv] Support SPV_EXT_demote_to_helper_invocation (#2773)

* [spirv] Support SPV_EXT_demote_to_helper_invocation

* address comments.
Ehsan 5 anos atrás
pai
commit
f6057fbee2

+ 1 - 0
tools/clang/include/clang/SPIRV/FeatureManager.h

@@ -35,6 +35,7 @@ enum class Extension {
   KHR_shader_draw_parameters,
   KHR_post_depth_coverage,
   KHR_ray_tracing,
+  EXT_demote_to_helper_invocation,
   EXT_descriptor_indexing,
   EXT_fragment_fully_covered,
   EXT_fragment_invocation_density,

+ 3 - 0
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -412,6 +412,9 @@ public:
                         llvm::ArrayRef<SpirvInstruction *> operands,
                         SourceLocation loc);
 
+  /// \brief Creates an OpDemoteToHelperInvocationEXT instruction.
+  SpirvInstruction *createDemoteToHelperInvocationEXT(SourceLocation);
+
   // === SPIR-V Module Structure ===
   inline void setMemoryModel(spv::AddressingModel, spv::MemoryModel);
 

+ 21 - 0
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -117,6 +117,8 @@ public:
     IK_VectorShuffle,             // OpVectorShuffle
     IK_ArrayLength,               // OpArrayLength
     IK_RayTracingOpNV,            // NV raytracing ops
+
+    IK_DemoteToHelperInvocationEXT, // OpDemoteToHelperInvocationEXT
   };
 
   virtual ~SpirvInstruction() = default;
@@ -1750,6 +1752,25 @@ private:
   llvm::SmallVector<SpirvInstruction *, 4> operands;
 };
 
+/// \brief OpDemoteToHelperInvocationEXT instruction.
+/// Demote fragment shader invocation to a helper invocation. Any stores to
+/// memory after this instruction are suppressed and the fragment does not write
+/// outputs to the framebuffer. Unlike the OpKill instruction, this does not
+/// necessarily terminate the invocation. It is not considered a flow control
+/// instruction (flow control does not become non-uniform) and does not
+/// terminate the block.
+class SpirvDemoteToHelperInvocationEXT : public SpirvInstruction {
+public:
+  SpirvDemoteToHelperInvocationEXT(SourceLocation);
+
+  // For LLVM-style RTTI
+  static bool classof(const SpirvInstruction *inst) {
+    return inst->getKind() == IK_DemoteToHelperInvocationEXT;
+  }
+
+  bool invokeVisitor(Visitor *v) override;
+};
+
 #undef DECLARE_INVOKE_VISITOR_FOR_CLASS
 
 } // namespace spirv

+ 1 - 0
tools/clang/include/clang/SPIRV/SpirvVisitor.h

@@ -113,6 +113,7 @@ public:
   DEFINE_VISIT_METHOD(SpirvVectorShuffle)
   DEFINE_VISIT_METHOD(SpirvArrayLength)
   DEFINE_VISIT_METHOD(SpirvRayTracingOpNV)
+  DEFINE_VISIT_METHOD(SpirvDemoteToHelperInvocationEXT)
 
 #undef DEFINE_VISIT_METHOD
 

+ 8 - 0
tools/clang/lib/SPIRV/CapabilityVisitor.cpp

@@ -553,5 +553,13 @@ bool CapabilityVisitor::visit(SpirvExtInst *instr) {
   return visitInstruction(instr);
 }
 
+bool CapabilityVisitor::visit(SpirvDemoteToHelperInvocationEXT *inst) {
+  addCapability(spv::Capability::DemoteToHelperInvocationEXT,
+                inst->getSourceLocation());
+  addExtension(Extension::EXT_demote_to_helper_invocation, "discard",
+               inst->getSourceLocation());
+  return true;
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 1 - 0
tools/clang/lib/SPIRV/CapabilityVisitor.h

@@ -33,6 +33,7 @@ public:
   bool visit(SpirvImageOp *);
   bool visit(SpirvImageSparseTexelsResident *);
   bool visit(SpirvExtInst *);
+  bool visit(SpirvDemoteToHelperInvocationEXT *);
 
   /// The "sink" visit function for all instructions.
   ///

+ 6 - 0
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -1120,6 +1120,12 @@ bool EmitVisitor::visit(SpirvRayTracingOpNV *inst) {
   return true;
 }
 
+bool EmitVisitor::visit(SpirvDemoteToHelperInvocationEXT *inst) {
+  initInstruction(inst);
+  finalizeInstruction();
+  return true;
+}
+
 // EmitTypeHandler ------
 
 void EmitTypeHandler::initTypeInstruction(spv::Op op) {

+ 1 - 0
tools/clang/lib/SPIRV/EmitVisitor.h

@@ -258,6 +258,7 @@ public:
   bool visit(SpirvVectorShuffle *);
   bool visit(SpirvArrayLength *);
   bool visit(SpirvRayTracingOpNV *);
+  bool visit(SpirvDemoteToHelperInvocationEXT *);
 
   // Returns the assembled binary built up in this visitor.
   std::vector<uint32_t> takeBinary();

+ 5 - 2
tools/clang/lib/SPIRV/FeatureManager.cpp

@@ -107,6 +107,8 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
       .Case("SPV_KHR_shader_draw_parameters",
             Extension::KHR_shader_draw_parameters)
       .Case("SPV_KHR_ray_tracing", Extension::KHR_ray_tracing)
+      .Case("SPV_EXT_demote_to_helper_invocation",
+            Extension::EXT_demote_to_helper_invocation)
       .Case("SPV_EXT_descriptor_indexing", Extension::EXT_descriptor_indexing)
       .Case("SPV_EXT_fragment_fully_covered",
             Extension::EXT_fragment_fully_covered)
@@ -122,8 +124,7 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
             Extension::AMD_shader_explicit_vertex_parameter)
       .Case("SPV_GOOGLE_hlsl_functionality1",
             Extension::GOOGLE_hlsl_functionality1)
-      .Case("SPV_GOOGLE_user_type",
-            Extension::GOOGLE_user_type)
+      .Case("SPV_GOOGLE_user_type", Extension::GOOGLE_user_type)
       .Case("SPV_KHR_post_depth_coverage", Extension::KHR_post_depth_coverage)
       .Case("SPV_NV_ray_tracing", Extension::NV_ray_tracing)
       .Case("SPV_NV_mesh_shader", Extension::NV_mesh_shader)
@@ -146,6 +147,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) {
     return "SPV_KHR_post_depth_coverage";
   case Extension::KHR_ray_tracing:
     return "SPV_KHR_ray_tracing";
+  case Extension::EXT_demote_to_helper_invocation:
+    return "SPV_EXT_demote_to_helper_invocation";
   case Extension::EXT_descriptor_indexing:
     return "SPV_EXT_descriptor_indexing";
   case Extension::EXT_fragment_fully_covered:

+ 8 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -769,6 +769,14 @@ SpirvBuilder::createRayTracingOpsNV(spv::Op opcode, QualType resultType,
   return inst;
 }
 
+SpirvInstruction *
+SpirvBuilder::createDemoteToHelperInvocationEXT(SourceLocation loc) {
+  assert(insertPoint && "null insert point");
+  auto *inst = new (context) SpirvDemoteToHelperInvocationEXT(loc);
+  insertPoint->addInstruction(inst);
+  return inst;
+}
+
 void SpirvBuilder::addModuleProcessed(llvm::StringRef process) {
   mod->addModuleProcessed(new (context) SpirvModuleProcessed({}, process));
 }

+ 16 - 6
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -1398,12 +1398,22 @@ spv::LoopControlMask SpirvEmitter::translateLoopAttribute(const Stmt *stmt,
 
 void SpirvEmitter::doDiscardStmt(const DiscardStmt *discardStmt) {
   assert(!spvBuilder.isCurrentBasicBlockTerminated());
-  spvBuilder.createKill(discardStmt->getLoc());
-  // Some statements that alter the control flow (break, continue, return, and
-  // discard), require creation of a new basic block to hold any statement that
-  // may follow them.
-  auto *newBB = spvBuilder.createBasicBlock();
-  spvBuilder.setInsertPoint(newBB);
+
+  // The discard statement can only be called from a pixel shader
+  if (!spvContext.isPS()) {
+    emitError("discard statement may only be used in pixel shaders",
+              discardStmt->getLoc());
+    return;
+  }
+
+  // SPV_EXT_demote_to_helper_invocation SPIR-V extension provides a new
+  // instruction OpDemoteToHelperInvocationEXT allowing shaders to "demote" a
+  // fragment shader invocation to behave like a helper invocation for its
+  // duration. The demoted invocation will have no further side effects and will
+  // not output to the framebuffer, but remains active and can participate in
+  // computing derivatives and in subgroup operations. This is a better match
+  // for the "discard" instruction in HLSL.
+  spvBuilder.createDemoteToHelperInvocationEXT(discardStmt->getLoc());
 }
 
 void SpirvEmitter::doDoStmt(const DoStmt *theDoStmt,

+ 8 - 0
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -81,6 +81,7 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvUnaryOp)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvVectorShuffle)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvArrayLength)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvRayTracingOpNV)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDemoteToHelperInvocationEXT)
 
 #undef DEFINE_INVOKE_VISITOR_FOR_CLASS
 
@@ -758,5 +759,12 @@ SpirvRayTracingOpNV::SpirvRayTracingOpNV(
     llvm::ArrayRef<SpirvInstruction *> vecOperands, SourceLocation loc)
     : SpirvInstruction(IK_RayTracingOpNV, opcode, resultType, loc),
       operands(vecOperands.begin(), vecOperands.end()) {}
+
+SpirvDemoteToHelperInvocationEXT::SpirvDemoteToHelperInvocationEXT(
+    SourceLocation loc)
+    : SpirvInstruction(IK_DemoteToHelperInvocationEXT,
+                       spv::Op::OpDemoteToHelperInvocationEXT, /*QualType*/ {},
+                       loc) {}
+
 } // namespace spirv
 } // namespace clang

+ 23 - 0
tools/clang/test/CodeGenSPIRV/cf.discard.cs.hlsl

@@ -0,0 +1,23 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// According to the HLS spec, discard can only be called from a pixel shader.
+
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+  int a, b;
+  bool cond = true;
+  while(cond) {
+    if(a==b) {
+// CHECK: :13:7: error: discard statement may only be used in pixel shaders
+      discard;
+      break;
+    } else {
+      ++a;
+// CHECK: :18:7: error: discard statement may only be used in pixel shaders
+      discard;
+      continue;
+      --b;
+    }
+  }
+}

+ 11 - 11
tools/clang/test/CodeGenSPIRV/cf.discard.hlsl

@@ -1,7 +1,9 @@
 // Run: %dxc -T ps_6_0 -E main
 
 // According to the HLS spec, discard can only be called from a pixel shader.
-// This translates to OpKill in SPIR-V. OpKill must be the last instruction in a block.
+
+// CHECK: OpCapability DemoteToHelperInvocationEXT
+// CHECK: OpExtension "SPV_EXT_demote_to_helper_invocation"
 
 void main() {
   int a, b;
@@ -11,20 +13,18 @@ void main() {
 // CHECK: %while_body = OpLabel
     if(a==b) {
 // CHECK: %if_true = OpLabel
-// CHECK-NEXT: OpKill
-      {{discard;}}
-      discard;  // No SPIR-V should be emitted for this statement.
-      break;    // No SPIR-V should be emitted for this statement.
+// CHECK: OpDemoteToHelperInvocationEXT
+      discard;
+      break;
     } else {
-// CHECK-NEXT: %if_false = OpLabel
+// CHECK: %if_false = OpLabel
       ++a;
-// CHECK: OpKill
+// CHECK: OpDemoteToHelperInvocationEXT
       discard;
-      continue; // No SPIR-V should be emitted for this statement.
-      --b;      // No SPIR-V should be emitted for this statement.
+      continue;
+      --b;
     }
-// CHECK-NEXT: %if_merge = OpLabel
-
+// CHECK: %if_merge = OpLabel
   }
 // CHECK: %while_merge = OpLabel
 

+ 4 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -502,6 +502,10 @@ TEST_F(FileTest, BreakStmtMixed) { runFileTest("cf.break.mixed.hlsl"); }
 
 // For discard statement
 TEST_F(FileTest, Discard) { runFileTest("cf.discard.hlsl"); }
+TEST_F(FileTest, DiscardCS) {
+  // Using discard is only allowed in pixel shaders.
+  runFileTest("cf.discard.cs.hlsl", Expect::Failure);
+}
 
 // For return statement
 TEST_F(FileTest, EarlyReturn) { runFileTest("cf.return.early.hlsl"); }