فهرست منبع

Merged PR 101: Store outputs before IgnoreHit/AcceptHitAndEndSearch

Store outputs before IgnoreHit/AcceptHitAndEndSearch

SROA_Parameter_HLSL would copy arguments to/from allocas at entry and each return.  These instrinsics need outputs to have been updated before the call, and do not return (have no-return attribute).

Make sure we copy to outputs before the function call by:
- Split block and add return immediately after call
- SROA_ParameterHLSL will then insert copy to outputs before each return
- In HLOperationLower, move call to just before the return (after copy to outputs)
- Later, the return (all code following the no-return call) will be replaced with 'unreachable'

Also fix build error in DxilPatchShaderRecordBindings.cpp with unused local var cast (void*)index;
Tex Riddell 7 سال پیش
والد
کامیت
b4f0e37581

+ 1 - 1
lib/HLSL/DxilPatchShaderRecordBindings.cpp

@@ -840,7 +840,7 @@ void DxilPatchShaderRecordBindings::InitializeViewTable() {
           pInputShaderInfo->pUAVRegisterSpaceArray, 
           *pInputShaderInfo->pNumUAVSpaces, 
           FallbackLayerNumDescriptorHeapSpacesPerView);
-        (void*)index;
+        (void)index;
         assert(index == 0);
     }
 }

+ 17 - 2
lib/HLSL/HLOperationLower.cpp

@@ -4373,6 +4373,21 @@ Value *TranslateNoArgMatrixOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode o
   return retVal;
 }
 
+Value *TranslateNoArgNoReturnPreserveOutput(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
+  HLOperationLowerHelper &helper, HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
+  Instruction *pResult = cast<Instruction>(
+    TrivialNoArgOperation(CI, IOP, opcode, helper, pObjHelper, Translated));
+  // HL intrinsic must have had a return injected just after the call.
+  // SROA_Parameter_HLSL will copy from alloca to output just before each return.
+  // Now move call after the copy and just before the return.
+  if (isa<ReturnInst>(pResult->getNextNode()))
+    return pResult;
+  ReturnInst *RetI = cast<ReturnInst>(pResult->getParent()->getTerminator());
+  pResult->removeFromParent();
+  pResult->insertBefore(RetI);
+  return pResult;
+}
+
 } // namespace
 
 // Lower table.
@@ -4408,7 +4423,7 @@ Value *StreamOutputLower(CallInst *CI, IntrinsicOp IOP, DXIL::OpCode opcode,
 
 // This table has to match IntrinsicOp orders
 IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] = {
-    {IntrinsicOp::IOP_AcceptHitAndEndSearch, TrivialNoArgOperation, DXIL::OpCode::AcceptHitAndEndSearch},
+    {IntrinsicOp::IOP_AcceptHitAndEndSearch, TranslateNoArgNoReturnPreserveOutput, DXIL::OpCode::AcceptHitAndEndSearch},
     {IntrinsicOp::IOP_AddUint64,  TranslateAddUint64,  DXIL::OpCode::UAddc},
     {IntrinsicOp::IOP_AllMemoryBarrier, TrivialBarrier, DXIL::OpCode::Barrier},
     {IntrinsicOp::IOP_AllMemoryBarrierWithGroupSync, TrivialBarrier, DXIL::OpCode::Barrier},
@@ -4428,7 +4443,7 @@ IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] =
     {IntrinsicOp::IOP_GroupMemoryBarrier, TrivialBarrier, DXIL::OpCode::Barrier},
     {IntrinsicOp::IOP_GroupMemoryBarrierWithGroupSync, TrivialBarrier, DXIL::OpCode::Barrier},
     {IntrinsicOp::IOP_HitKind, TrivialNoArgWithRetOperation, DXIL::OpCode::HitKind},
-    {IntrinsicOp::IOP_IgnoreHit, TrivialNoArgOperation, DXIL::OpCode::IgnoreHit},
+    {IntrinsicOp::IOP_IgnoreHit, TranslateNoArgNoReturnPreserveOutput, DXIL::OpCode::IgnoreHit},
     {IntrinsicOp::IOP_InstanceID, TrivialNoArgWithRetOperation, DXIL::OpCode::InstanceID},
     {IntrinsicOp::IOP_InstanceIndex, TrivialNoArgWithRetOperation, DXIL::OpCode::InstanceIndex},
     {IntrinsicOp::IOP_InterlockedAdd, TranslateIopAtomicBinaryOperation, DXIL::OpCode::NumOpCodes},

+ 41 - 0
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -4184,6 +4184,7 @@ static void LegalizeDxilInputOutputs(Function *F,
                                      DxilFunctionAnnotation *EntryAnnotation,
                                      const DataLayout &DL,
                                      DxilTypeSystem &typeSys);
+static void InjectReturnAfterNoReturnPreserveOutput(HLModule &HLM);
 
 namespace {
 class SROA_Parameter_HLSL : public ModulePass {
@@ -4205,6 +4206,8 @@ public:
     // used to load them.
     m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
 
+    InjectReturnAfterNoReturnPreserveOutput(*m_pHLModule);
+
     std::deque<Function *> WorkList;
     std::vector<Function *> DeadHLFunctions;
     for (Function &F : M.functions()) {
@@ -5730,6 +5733,44 @@ static void CheckArgUsage(Value *V, bool &bLoad, bool &bStore) {
     }
   }
 }
+
+// AcceptHitAndEndSearch and IgnoreHit both will not return, but require
+// outputs to have been written before the call.  Do this by:
+//  - inject a return immediately after the call if not there already
+//  - LegalizeDxilInputOutputs will inject writes from temp alloca to
+//    outputs before each return.
+//  - in HLOperationLower, after lowering the intrinsic, move the intrinsic
+//    to just before the return.
+static void InjectReturnAfterNoReturnPreserveOutput(HLModule &HLM) {
+  for (Function &F : HLM.GetModule()->functions()) {
+    if (GetHLOpcodeGroup(&F) == HLOpcodeGroup::HLIntrinsic) {
+      for (auto U : F.users()) {
+        if (CallInst *CI = dyn_cast<CallInst>(U)) {
+          unsigned OpCode = GetHLOpcode(CI);
+          if (OpCode == (unsigned)IntrinsicOp::IOP_AcceptHitAndEndSearch ||
+              OpCode == (unsigned)IntrinsicOp::IOP_IgnoreHit) {
+            Instruction *pNextI = CI->getNextNode();
+            // Skip if already has a return immediatly following call
+            if (isa<ReturnInst>(pNextI))
+              continue;
+            // split block and add return:
+            BasicBlock *BB = CI->getParent();
+            BasicBlock *NewBB = BB->splitBasicBlock(pNextI);
+            TerminatorInst *Term = NewBB->getTerminator();
+            Term->eraseFromParent();
+            IRBuilder<> Builder(NewBB);
+            llvm::Type *RetTy = CI->getParent()->getParent()->getReturnType();
+            if (RetTy->isVoidTy())
+              Builder.CreateRetVoid();
+            else
+              Builder.CreateRet(UndefValue::get(RetTy));
+          }
+        }
+      }
+    }
+  }
+}
+
 // Support store to input and load from output.
 static void LegalizeDxilInputOutputs(Function *F,
                                      DxilFunctionAnnotation *EntryAnnotation,

+ 32 - 0
tools/clang/test/CodeGenHLSL/quick-test/raytracing_anyhit_accept_hit.hlsl

@@ -0,0 +1,32 @@
+// RUN: %dxc -T lib_6_3 -auto-binding-space 11 %s | FileCheck %s
+
+// CHECK: %[[color:[^ ]+]] = getelementptr inbounds %struct.MyPayload, %struct.MyPayload* %{{[^ ,]+}}, i32 0, i32 0
+// CHECK: store <4 x float> <float 1.250000e-01, float 2.500000e-01, float 5.000000e-01, float 1.000000e+00>, <4 x float>* %[[color]]
+// CHECK: call void @dx.op.acceptHitAndEndSearch
+// CHECK: unreachable
+// CHECK: %[[pos:[^ ]+]] = getelementptr inbounds %struct.MyPayload, %struct.MyPayload* %{{[^ ,]+}}, i32 0, i32 1
+// CHECK: store <2 x i32> <i32 1, i32 2>, <2 x i32>* %[[pos]]
+// CHECK: call void @dx.op.acceptHitAndEndSearch
+// CHECK: unreachable
+
+struct MyPayload {
+  float4 color;
+  uint2 pos;
+};
+
+struct MyAttributes {
+  float2 bary;
+  uint id;
+};
+
+[shader("anyhit")]
+void anyhit1( inout MyPayload payload : SV_RayPayload,
+              in MyAttributes attr : SV_IntersectionAttributes )
+{
+  float3 hitLocation = ObjectRayOrigin() + ObjectRayDirection() * RayTCurrent();
+  payload.color = float4(0.125, 0.25, 0.5, 1.0);
+  if (hitLocation.x < 0)
+    AcceptHitAndEndSearch();   // aborts function
+  payload.pos = uint2(1,2);
+  AcceptHitAndEndSearch();   // aborts function
+}

+ 32 - 0
tools/clang/test/CodeGenHLSL/quick-test/raytracing_anyhit_ignore_hit.hlsl

@@ -0,0 +1,32 @@
+// RUN: %dxc -T lib_6_3 -auto-binding-space 11 %s | FileCheck %s
+
+// CHECK: %[[color:[^ ]+]] = getelementptr inbounds %struct.MyPayload, %struct.MyPayload* %{{[^ ,]+}}, i32 0, i32 0
+// CHECK: store <4 x float> <float 1.250000e-01, float 2.500000e-01, float 5.000000e-01, float 1.000000e+00>, <4 x float>* %[[color]]
+// CHECK: call void @dx.op.ignoreHit
+// CHECK: unreachable
+// CHECK: %[[pos:[^ ]+]] = getelementptr inbounds %struct.MyPayload, %struct.MyPayload* %{{[^ ,]+}}, i32 0, i32 1
+// CHECK: store <2 x i32> <i32 1, i32 2>, <2 x i32>* %[[pos]]
+// CHECK: call void @dx.op.ignoreHit
+// CHECK: unreachable
+
+struct MyPayload {
+  float4 color;
+  uint2 pos;
+};
+
+struct MyAttributes {
+  float2 bary;
+  uint id;
+};
+
+[shader("anyhit")]
+void anyhit1( inout MyPayload payload : SV_RayPayload,
+              in MyAttributes attr : SV_IntersectionAttributes )
+{
+  float3 hitLocation = ObjectRayOrigin() + ObjectRayDirection() * RayTCurrent();
+  payload.color = float4(0.125, 0.25, 0.5, 1.0);
+  if (hitLocation.x < 0)
+    IgnoreHit();   // aborts function
+  payload.pos = uint2(1,2);
+  IgnoreHit();   // aborts function
+}