浏览代码

[SPIR-V]Support KHR_Ray_tracing terminate Ops (#3295)

* [SPIR-V]Support KHR_Ray_tracing terminate Ops

Add OpIgnoreIntersectionKHR/OpTerminateRayKHR related
https://github.com/microsoft/DirectXShaderCompiler/issues/3285

* Remove redudant message in the header of check hlsl

* rebase against master (fix merge conflict)

* Create a new basic block after termination.

* Don't change the SPIRV-Headers hash.

* Add unit tests and add assertion.

* Fix missing ==

Co-authored-by: Ehsan Nasiri <[email protected]>
JiaoluAMD 5 年之前
父节点
当前提交
ce2ec2f654

+ 4 - 0
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -499,6 +499,10 @@ public:
   /// \brief Creates an OpReadClockKHR instruction.
   /// \brief Creates an OpReadClockKHR instruction.
   SpirvInstruction *createReadClock(SpirvInstruction *scope, SourceLocation);
   SpirvInstruction *createReadClock(SpirvInstruction *scope, SourceLocation);
 
 
+  /// \brief Create Raytracing terminate Ops
+  /// OpIgnoreIntersectionKHR/OpTerminateIntersectionKHR
+  void createRaytracingTerminateKHR(spv::Op opcode, SourceLocation loc);
+
   // === SPIR-V Module Structure ===
   // === SPIR-V Module Structure ===
   inline void setMemoryModel(spv::AddressingModel, spv::MemoryModel);
   inline void setMemoryModel(spv::AddressingModel, spv::MemoryModel);
 
 

+ 23 - 7
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -77,12 +77,13 @@ public:
 
 
     // The following section is for termination instructions.
     // The following section is for termination instructions.
     // Used by LLVM-style RTTI; order matters.
     // Used by LLVM-style RTTI; order matters.
-    IK_Branch,            // OpBranch
-    IK_BranchConditional, // OpBranchConditional
-    IK_Kill,              // OpKill
-    IK_Return,            // OpReturn*
-    IK_Switch,            // OpSwitch
-    IK_Unreachable,       // OpUnreachable
+    IK_Branch,              // OpBranch
+    IK_BranchConditional,   // OpBranchConditional
+    IK_Kill,                // OpKill
+    IK_Return,              // OpReturn*
+    IK_Switch,              // OpSwitch
+    IK_Unreachable,         // OpUnreachable
+    IK_RayTracingTerminate, // OpIgnoreIntersectionKHR/OpTerminateRayKHR
 
 
     // Normal instruction kinds
     // Normal instruction kinds
     // In alphabetical order
     // In alphabetical order
@@ -634,6 +635,7 @@ private:
 ///
 ///
 /// * OpBranch, OpBranchConditional, OpSwitch
 /// * OpBranch, OpBranchConditional, OpSwitch
 /// * OpReturn, OpReturnValue, OpKill, OpUnreachable
 /// * OpReturn, OpReturnValue, OpKill, OpUnreachable
+/// * OpIgnoreIntersectionKHR, OpTerminateIntersectionKHR
 ///
 ///
 /// The first group (branching instructions) also include information on
 /// The first group (branching instructions) also include information on
 /// possible branches that will be taken next.
 /// possible branches that will be taken next.
@@ -641,7 +643,8 @@ class SpirvTerminator : public SpirvInstruction {
 public:
 public:
   // For LLVM-style RTTI
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
   static bool classof(const SpirvInstruction *inst) {
-    return inst->getKind() >= IK_Branch && inst->getKind() <= IK_Unreachable;
+    return inst->getKind() >= IK_Branch &&
+           inst->getKind() <= IK_RayTracingTerminate;
   }
   }
 
 
 protected:
 protected:
@@ -1958,6 +1961,19 @@ private:
   bool cullFlags;
   bool cullFlags;
 };
 };
 
 
+class SpirvRayTracingTerminateOpKHR : public SpirvTerminator {
+public:
+  SpirvRayTracingTerminateOpKHR(spv::Op opcode, SourceLocation loc);
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvRayTracingTerminateOpKHR)
+
+  // For LLVM-style RTTI
+  static bool classof(const SpirvInstruction *inst) {
+    return inst->getKind() == IK_RayTracingTerminate;
+  }
+
+  bool invokeVisitor(Visitor *v) override;
+};
+
 /// \brief OpDemoteToHelperInvocationEXT instruction.
 /// \brief OpDemoteToHelperInvocationEXT instruction.
 /// Demote fragment shader invocation to a helper invocation. Any stores to
 /// Demote fragment shader invocation to a helper invocation. Any stores to
 /// memory after this instruction are suppressed and the fragment does not write
 /// memory after this instruction are suppressed and the fragment does not write

+ 1 - 1
tools/clang/include/clang/SPIRV/SpirvVisitor.h

@@ -138,7 +138,7 @@ public:
 
 
   DEFINE_VISIT_METHOD(SpirvRayQueryOpKHR)
   DEFINE_VISIT_METHOD(SpirvRayQueryOpKHR)
   DEFINE_VISIT_METHOD(SpirvReadClock)
   DEFINE_VISIT_METHOD(SpirvReadClock)
-
+  DEFINE_VISIT_METHOD(SpirvRayTracingTerminateOpKHR)
 #undef DEFINE_VISIT_METHOD
 #undef DEFINE_VISIT_METHOD
 
 
 protected:
 protected:

+ 6 - 0
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -1645,6 +1645,12 @@ bool EmitVisitor::visit(SpirvReadClock *inst) {
   return true;
   return true;
 }
 }
 
 
+bool EmitVisitor::visit(SpirvRayTracingTerminateOpKHR *inst) {
+  initInstruction(inst);
+  finalizeInstruction(&mainBinary);
+  return true;
+}
+
 // EmitTypeHandler ------
 // EmitTypeHandler ------
 
 
 void EmitTypeHandler::initTypeInstruction(spv::Op op) {
 void EmitTypeHandler::initTypeInstruction(spv::Op op) {

+ 1 - 1
tools/clang/lib/SPIRV/EmitVisitor.h

@@ -267,7 +267,7 @@ public:
   bool visit(SpirvDemoteToHelperInvocationEXT *) override;
   bool visit(SpirvDemoteToHelperInvocationEXT *) override;
   bool visit(SpirvRayQueryOpKHR *) override;
   bool visit(SpirvRayQueryOpKHR *) override;
   bool visit(SpirvReadClock *) override;
   bool visit(SpirvReadClock *) override;
-
+  bool visit(SpirvRayTracingTerminateOpKHR *) override;
   bool visit(SpirvDebugInfoNone *) override;
   bool visit(SpirvDebugInfoNone *) override;
   bool visit(SpirvDebugSource *) override;
   bool visit(SpirvDebugSource *) override;
   bool visit(SpirvDebugCompilationUnit *) override;
   bool visit(SpirvDebugCompilationUnit *) override;

+ 7 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -939,6 +939,13 @@ SpirvInstruction *SpirvBuilder::createReadClock(SpirvInstruction *scope,
   return inst;
   return inst;
 }
 }
 
 
+void SpirvBuilder::createRaytracingTerminateKHR(spv::Op opcode,
+                                                SourceLocation loc) {
+  assert(insertPoint && "null insert point");
+  auto *inst = new (context) SpirvRayTracingTerminateOpKHR(opcode, loc);
+  insertPoint->addInstruction(inst);
+}
+
 void SpirvBuilder::addModuleProcessed(llvm::StringRef process) {
 void SpirvBuilder::addModuleProcessed(llvm::StringRef process) {
   mod->addModuleProcessed(new (context) SpirvModuleProcessed({}, process));
   mod->addModuleProcessed(new (context) SpirvModuleProcessed({}, process));
 }
 }

+ 25 - 5
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -7488,11 +7488,31 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
                                callExpr->getExprLoc());
                                callExpr->getExprLoc());
       }
       }
     }
     }
-    spvBuilder.createRayTracingOpsNV(
-        hlslOpcode == hlsl::IntrinsicOp ::IOP_AcceptHitAndEndSearch
-            ? spv::Op::OpTerminateRayNV
-            : spv::Op::OpIgnoreIntersectionNV,
-        QualType(), {}, srcLoc);
+    bool nvRayTracing =
+        featureManager.isExtensionEnabled(Extension::NV_ray_tracing);
+
+    if (nvRayTracing) {
+      spvBuilder.createRayTracingOpsNV(
+          hlslOpcode == hlsl::IntrinsicOp::IOP_AcceptHitAndEndSearch
+              ? spv::Op::OpTerminateRayNV
+              : spv::Op::OpIgnoreIntersectionNV,
+          QualType(), {}, srcLoc);
+    } else {
+      spvBuilder.createRaytracingTerminateKHR(
+          hlslOpcode == hlsl::IntrinsicOp::IOP_AcceptHitAndEndSearch
+              ? spv::Op::OpTerminateRayKHR
+              : spv::Op::OpIgnoreIntersectionKHR,
+          srcLoc);
+      // According to the SPIR-V spec, both OpTerminateRayKHR and
+      // OpIgnoreIntersectionKHR are termination instructions.
+      // The spec also requires that these instructions must be the last
+      // instruction in a block.
+      // Therefore we need to create a new basic block, and the following
+      // instructions will go there.
+      auto *newBB = spvBuilder.createBasicBlock();
+      spvBuilder.setInsertPoint(newBB);
+    }
+
     break;
     break;
   }
   }
   case hlsl::IntrinsicOp::IOP_ReportHit: {
   case hlsl::IntrinsicOp::IOP_ReportHit: {

+ 8 - 0
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -105,6 +105,7 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDebugTypeTemplate)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDebugTypeTemplateParameter)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDebugTypeTemplateParameter)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvRayQueryOpKHR)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvRayQueryOpKHR)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvReadClock)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvReadClock)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvRayTracingTerminateOpKHR)
 
 
 #undef DEFINE_INVOKE_VISITOR_FOR_CLASS
 #undef DEFINE_INVOKE_VISITOR_FOR_CLASS
 
 
@@ -987,5 +988,12 @@ SpirvReadClock::SpirvReadClock(QualType resultType, SpirvInstruction *s,
     : SpirvInstruction(IK_ReadClock, spv::Op::OpReadClockKHR, resultType, loc),
     : SpirvInstruction(IK_ReadClock, spv::Op::OpReadClockKHR, resultType, loc),
       scope(s) {}
       scope(s) {}
 
 
+SpirvRayTracingTerminateOpKHR::SpirvRayTracingTerminateOpKHR(spv::Op opcode,
+                                                             SourceLocation loc)
+    : SpirvTerminator(IK_RayTracingTerminate, opcode, loc) {
+  assert(opcode == spv::Op::OpTerminateRayKHR ||
+         opcode == spv::Op::OpIgnoreIntersectionKHR);
+}
+
 } // namespace spirv
 } // namespace spirv
 } // namespace clang
 } // namespace clang

+ 53 - 0
tools/clang/test/CodeGenSPIRV/raytracing.khr.terminate.hlsl

@@ -0,0 +1,53 @@
+// Run: %dxc -T lib_6_3 -fspv-target-env=vulkan1.2
+// CHECK:  OpCapability RayTracingKHR
+// CHECK:  OpExtension "SPV_KHR_ray_tracing"
+
+// CHECK:  OpDecorate [[l:%\d+]] BuiltIn HitKindNV
+
+// CHECK:  OpTypePointer IncomingRayPayloadNV %Payload
+struct Payload
+{
+  float4 color;
+};
+// CHECK:  OpTypePointer HitAttributeNV %Attribute
+struct Attribute
+{
+  float2 bary;
+};
+
+[shader("anyhit")]
+void MyAHitMain(inout Payload MyPayload, in Attribute MyAttr) {
+
+// CHECK:  OpLoad %uint [[l]]
+  uint _16 = HitKind();
+
+// CHECK: %if_true = OpLabel
+  if (_16 == 1U) {
+// CHECK: OpIgnoreIntersectionKHR
+    IgnoreHit();
+// CHECK-NOT: OpLoad %uint %_16
+// CHECK-NOT: OpStore
+    uint a = _16;
+// CHECK-NEXT: %if_false = OpLabel
+  } else {
+// CHECK: OpTerminateRayKHR
+    AcceptHitAndEndSearch();
+// CHECK-NOT: OpLoad %uint %_16
+// CHECK-NOT: OpStore
+    uint b = _16;
+  }
+// CHECK-NEXT: %if_merge = OpLabel
+// CHECK-NEXT: OpReturn
+// CHECK-NEXT: OpFunctionEnd
+}
+
+
+[shader("anyhit")]
+void MyAHitMain2(inout Payload MyPayload, in Attribute MyAttr) {
+// CHECK: OpTerminateRayKHR
+    AcceptHitAndEndSearch();
+// CHECK-NOT: OpAccessChain
+// CHECK-NOT: OpStore
+    MyPayload.color = 0.xxxx;
+// CHECK-NEXT: OpFunctionEnd
+}

+ 4 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -2244,6 +2244,10 @@ TEST_F(FileTest, RayTracingAccelerationStructure) {
   runFileTest("raytracing.acceleration-structure.hlsl");
   runFileTest("raytracing.acceleration-structure.hlsl");
 }
 }
 
 
+TEST_F(FileTest, RayTracingTerminate) {
+  runFileTest("raytracing.khr.terminate.hlsl");
+}
+
 // For decoration uniqueness
 // For decoration uniqueness
 TEST_F(FileTest, DecorationUnique) { runFileTest("decoration.unique.hlsl"); }
 TEST_F(FileTest, DecorationUnique) { runFileTest("decoration.unique.hlsl"); }
 
 

+ 18 - 0
tools/clang/unittests/SPIRV/SpirvBasicBlockTest.cpp

@@ -97,6 +97,24 @@ TEST_F(SpirvBasicBlockTest, CheckTerminatedByUnreachable) {
   EXPECT_TRUE(bb.hasTerminator());
   EXPECT_TRUE(bb.hasTerminator());
 }
 }
 
 
+TEST_F(SpirvBasicBlockTest, CheckTerminatedByTerminateRay) {
+  SpirvBasicBlock bb("bb");
+  SpirvContext &context = getSpirvContext();
+  auto *khrTerminateRay = new (context)
+      SpirvRayTracingTerminateOpKHR(spv::Op::OpTerminateRayKHR, {});
+  bb.addInstruction(khrTerminateRay);
+  EXPECT_TRUE(bb.hasTerminator());
+}
+
+TEST_F(SpirvBasicBlockTest, CheckTerminatedByIgnoreIntersection) {
+  SpirvBasicBlock bb("bb");
+  SpirvContext &context = getSpirvContext();
+  auto *khrIgnoreIntersection = new (context)
+      SpirvRayTracingTerminateOpKHR(spv::Op::OpIgnoreIntersectionKHR, {});
+  bb.addInstruction(khrIgnoreIntersection);
+  EXPECT_TRUE(bb.hasTerminator());
+}
+
 TEST_F(SpirvBasicBlockTest, CheckNotTerminated) {
 TEST_F(SpirvBasicBlockTest, CheckNotTerminated) {
   SpirvBasicBlock bb("bb");
   SpirvBasicBlock bb("bb");
   SpirvContext &context = getSpirvContext();
   SpirvContext &context = getSpirvContext();