2
0
Эх сурвалжийг харах

[spirv] Fix codegen for inout resource function parameters. (#3112)

Ehsan 5 жил өмнө
parent
commit
c5e4626cf2

+ 12 - 2
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -2056,8 +2056,13 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
     // If argInfo is nullptr and argInst is a rvalue, we do not have a proper
     // pointer to pass to the function. we need a temporary variable in that
     // case.
+    //
+    // If we have an 'out/inout' resource as function argument, we need to
+    // create a temporary variable for it because the function definition
+    // expects are point-to-pointer argument for resources, which will be
+    // resolved by legalization.
     if ((argInfo || (argInst && !argInst->isRValue())) &&
-        canActAsOutParmVar(param) &&
+        canActAsOutParmVar(param) && !isResourceType(param) &&
         paramTypeMatchesArgType(param->getType(), arg->getType())) {
       // Based on SPIR-V spec, function parameter must be always Function
       // scope. In addition, we must pass memory object declaration argument
@@ -2141,7 +2146,12 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
     // If it calls a non-static member function, the object itself is argument
     // 0, and therefore all other argument positions are shifted by 1.
     const uint32_t index = i + isNonStaticMemberCall;
-    if (isTempVar[index] && canActAsOutParmVar(param)) {
+    // Using a resouce as a function parameter is never passed-by-copy. As a
+    // result, even if the function parameter is marked as 'out' or 'inout',
+    // there is no reason to copy back the results after the function call into
+    // the resource.
+    if (isTempVar[index] && canActAsOutParmVar(param) &&
+        !isResourceType(param)) {
       const auto *arg = callExpr->getArg(i);
       SpirvInstruction *value = spvBuilder.createLoad(
           param->getType(), vars[index], arg->getLocStart());

+ 21 - 0
tools/clang/test/CodeGenSPIRV/fn.param.inout.resource.hlsl

@@ -0,0 +1,21 @@
+// Run: %dxc -E main -T cs_6_0
+
+RWStructuredBuffer<int> testrwbuf : register(u0);
+
+void testfn(uint index, uint value, inout RWStructuredBuffer<int> buf);
+
+[numthreads(1, 1, 1)]
+void main(uint3 GroupId          : SV_GroupID,
+          uint3 DispatchThreadId : SV_DispatchThreadID)
+{
+// CHECK: %param_var_buf = OpVariable %_ptr_Function__ptr_Uniform_type_RWStructuredBuffer_int Function
+// CHECK:       {{%\d+}} = OpFunctionCall %void %testfn %param_var_index %param_var_value %param_var_buf
+  testfn(GroupId.x, DispatchThreadId.x, testrwbuf);
+}
+
+// CHECK:   %buf = OpFunctionParameter %_ptr_Function__ptr_Uniform_type_RWStructuredBuffer_int
+void testfn(uint index, uint value, inout RWStructuredBuffer<int> buf) {
+  buf[index] = value;
+}
+
+

+ 4 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -547,6 +547,10 @@ TEST_F(FileTest, FunctionInOutParamVector) {
   setBeforeHLSLLegalization();
   runFileTest("fn.param.inout.vector.hlsl");
 }
+TEST_F(FileTest, FunctionInOutParamResource) {
+  setBeforeHLSLLegalization();
+  runFileTest("fn.param.inout.resource.hlsl");
+}
 TEST_F(FileTest, FunctionInOutParamDiffStorageClass) {
   setBeforeHLSLLegalization();
   runFileTest("fn.param.inout.storage-class.hlsl");