Procházet zdrojové kódy

Fix bug in any-hit shaders not writing out payload for early exits. (#2517)

alelenv před 6 roky
rodič
revize
1e52699f2c

+ 24 - 7
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -7098,14 +7098,30 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
     retVal = processRayBuiltins(callExpr, hlslOpcode);
     break;
   }
-  case hlsl::IntrinsicOp::IOP_AcceptHitAndEndSearch: {
-    spvBuilder.createRayTracingOpsNV(spv::Op::OpTerminateRayNV, QualType(), {},
-                                     srcLoc);
-    break;
-  }
+  case hlsl::IntrinsicOp::IOP_AcceptHitAndEndSearch:
   case hlsl::IntrinsicOp::IOP_IgnoreHit: {
-    spvBuilder.createRayTracingOpsNV(spv::Op::OpIgnoreIntersectionNV,
-                                     QualType(), {}, srcLoc);
+
+    // Any modifications made to the ray payload in an any hit shader are
+    // preserved before calling AcceptHit/IgnoreHit. Write out the results to
+    // the payload which is visible only in entry functions
+    const auto iter = functionInfoMap.find(curFunction);
+    if (iter != functionInfoMap.end()) {
+      const auto &entryInfo = iter->second;
+      if (entryInfo->isEntryFunction) {
+        const auto payloadArg = curFunction->getParamDecl(0);
+        const auto payloadArgInst =
+            declIdMapper.getDeclEvalInfo(payloadArg, payloadArg->getLocStart());
+        auto tempLoad = spvBuilder.createLoad(
+            payloadArg->getType(), payloadArgInst, payloadArg->getLocStart());
+        spvBuilder.createStore(currentRayPayload, tempLoad,
+                               callExpr->getExprLoc());
+      }
+    }
+    spvBuilder.createRayTracingOpsNV(
+        hlslOpcode == hlsl::IntrinsicOp ::IOP_AcceptHitAndEndSearch
+            ? spv::Op::OpTerminateRayNV
+            : spv::Op::OpIgnoreIntersectionNV,
+        QualType(), {}, srcLoc);
     break;
   }
   case hlsl::IntrinsicOp::IOP_ReportHit: {
@@ -10357,6 +10373,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
         // First argument is always rayPayload
         curStageVar = declIdMapper.createRayTracingNVStageVar(
             spv::StorageClass::IncomingRayPayloadNV, param);
+        currentRayPayload = curStageVar;
       } else {
         // Second argument is always attribute
         curStageVar = declIdMapper.createRayTracingNVStageVar(

+ 5 - 0
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -1145,6 +1145,11 @@ private:
   llvm::SmallDenseMap<QualType,
                       std::pair<SpirvInstruction *, SpirvInstruction *>, 4>
       callDataMap;
+
+  /// Incoming ray payload for current entry function being translated.
+  /// Only valid for any-hit/closest-hit ray tracing shaders.
+  SpirvInstruction *currentRayPayload;
+
   /// This is the Patch Constant Function. This function is not explicitly
   /// called from the entry point function.
   FunctionDecl *patchConstFunc;

+ 6 - 2
tools/clang/test/CodeGenSPIRV/raytracing.nv.anyhit.hlsl

@@ -69,10 +69,14 @@ void main(inout Payload MyPayload, in Attribute MyAttr) {
   uint _16 = HitKind();
 
   if (_16 == 1U) {
-// CHECK:  OpIgnoreIntersectionNV
+// CHECK:  [[payloadread0:%\d+]] = OpLoad %Payload %MyPayload_0
+// CHECK-NEXT : OpStore %MyPayload [[payloadread0]]
+// CHECK-NEXT : OpIgnoreIntersectionNV
     IgnoreHit();
   } else {
-// CHECK:  OpTerminateRayNV
+// CHECK:  [[payloadread1:%\d+]] = OpLoad %Payload %MyPayload_0
+// CHECK-NEXT : OpStore %MyPayload [[payloadread1]]
+// CHECK-NEXT : OpTerminateRayNV
     AcceptHitAndEndSearch();
   }
 }