Quellcode durchsuchen

[spirv] Avoid creating temporary variables for local variables (#1293)

When calling functions with arguments that references local
variables as a whole and annotated with out/inout, we don't
need to create temporary variables for them again. Just pass
the pointers to the original local variables.
Lei Zhang vor 7 Jahren
Ursprung
Commit
b17ed7f06d

+ 41 - 21
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -2001,7 +2001,8 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
   SpirvEvalInfo objectEvalInfo = 0; // EvalInfo for the object (if exists)
   bool needsTempVar = false;        // Whether we need temporary variable.
 
-  llvm::SmallVector<uint32_t, 4> params;    // Temporary variables
+  llvm::SmallVector<uint32_t, 4> vars;      // Variables for function call
+  llvm::SmallVector<bool, 4> isTempVar;     // Temporary variable or not
   llvm::SmallVector<SpirvEvalInfo, 4> args; // Evaluated arguments
 
   if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr)) {
@@ -2040,7 +2041,8 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
       args.push_back(objectId);
       // We do not need to create a new temporary variable for the this
       // object. Use the evaluated argument.
-      params.push_back(args.back());
+      vars.push_back(args.back());
+      isTempVar.push_back(false);
     }
   }
 
@@ -2052,25 +2054,44 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
     auto *arg = callExpr->getArg(i)->IgnoreParenLValueCasts();
     const auto *param = callee->getParamDecl(i);
 
-    // We need to create variables for holding the values to be used as
-    // arguments. The variables themselves are of pointer types.
-    const uint32_t varType =
-        declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param);
-    const std::string varName = "param.var." + param->getNameAsString();
-    const uint32_t tempVarId = theBuilder.addFnVar(varType, varName);
+    // Get the evaluation info if this argument is referencing some variable
+    // *as a whole*, in which case we can avoid creating the temporary variable
+    // for it if it is Function scope and can act as out parameter.
+    SpirvEvalInfo argInfo = 0;
+    if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(arg)) {
+      argInfo = declIdMapper.getDeclEvalInfo(declRefExpr->getDecl());
+    }
+
+    if (argInfo && argInfo.getStorageClass() == spv::StorageClass::Function &&
+        canActAsOutParmVar(param)) {
+      vars.push_back(argInfo);
+      isTempVar.push_back(false);
+      args.push_back(doExpr(arg));
+    } else {
+      // We need to create variables for holding the values to be used as
+      // arguments. The variables themselves are of pointer types.
+      const uint32_t varType =
+          declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param);
+      const std::string varName = "param.var." + param->getNameAsString();
+      const uint32_t tempVarId = theBuilder.addFnVar(varType, varName);
 
-    params.push_back(tempVarId);
-    args.push_back(doExpr(arg));
+      vars.push_back(tempVarId);
+      isTempVar.push_back(true);
+      args.push_back(doExpr(arg));
 
-    // Update counter variable associated with function parameters
-    tryToAssignCounterVar(param, arg);
+      // Update counter variable associated with function parameters
+      tryToAssignCounterVar(param, arg);
 
-    // Manually load the argument here
-    const auto rhsVal = loadIfGLValue(arg, args.back());
-    // Initialize the temporary variables using the contents of the arguments
-    storeValue(tempVarId, rhsVal, param->getType());
+      // Manually load the argument here
+      const auto rhsVal = loadIfGLValue(arg, args.back());
+      // Initialize the temporary variables using the contents of the arguments
+      storeValue(tempVarId, rhsVal, param->getType());
+    }
   }
 
+  assert(vars.size() == isTempVar.size());
+  assert(vars.size() == args.size());
+
   // Push the callee into the work queue if it is not there.
   if (!workQueue.count(callee)) {
     workQueue.insert(callee);
@@ -2081,26 +2102,25 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
   // Get or forward declare the function <result-id>
   const uint32_t funcId = declIdMapper.getOrRegisterFnResultId(callee);
 
-  const uint32_t retVal =
-      theBuilder.createFunctionCall(retType, funcId, params);
+  const uint32_t retVal = theBuilder.createFunctionCall(retType, funcId, vars);
 
   // If we created a temporary variable for the lvalue object this method is
   // invoked upon, we need to copy the contents in the temporary variable back
   // to the original object's variable in case there are side effects.
   if (needsTempVar && !objectEvalInfo.isRValue()) {
     const uint32_t typeId = typeTranslator.translateType(objectType);
-    const uint32_t value = theBuilder.createLoad(typeId, params.front());
+    const uint32_t value = theBuilder.createLoad(typeId, vars.front());
     storeValue(objectEvalInfo, value, objectType);
   }
 
   // Go through all parameters and write those marked as out/inout
   for (uint32_t i = 0; i < numParams; ++i) {
     const auto *param = callee->getParamDecl(i);
-    if (canActAsOutParmVar(param)) {
+    if (isTempVar[i] && canActAsOutParmVar(param)) {
       const auto *arg = callExpr->getArg(i);
       const uint32_t index = i + isNonStaticMemberCall;
       const uint32_t typeId = typeTranslator.translateType(param->getType());
-      const uint32_t value = theBuilder.createLoad(typeId, params[index]);
+      const uint32_t value = theBuilder.createLoad(typeId, vars[index]);
 
       processAssignment(arg, value, false, args[index]);
     }

+ 1 - 19
tools/clang/test/CodeGenSPIRV/fn.param.inout.hlsl

@@ -11,36 +11,18 @@ float fnInOut(uniform float a, in float b, out float c, inout float d, inout Pix
 }
 
 float main(float val: A) : B {
-// CHECK-LABEL: %src_main = OpFunction
     float m, n;
     Pixel p;
 
 // CHECK:      %param_var_a = OpVariable %_ptr_Function_float Function
 // CHECK-NEXT: %param_var_b = OpVariable %_ptr_Function_float Function
-// CHECK-NEXT: %param_var_c = OpVariable %_ptr_Function_float Function
-// CHECK-NEXT: %param_var_d = OpVariable %_ptr_Function_float Function
-// CHECK-NEXT: %param_var_e = OpVariable %_ptr_Function_Pixel Function
 
 // CHECK-NEXT:                OpStore %param_var_a %float_5
 // CHECK-NEXT: [[val:%\d+]] = OpLoad %float %val
 // CHECK-NEXT:                OpStore %param_var_b [[val]]
-// CHECK-NEXT:   [[m:%\d+]] = OpLoad %float %m
-// CHECK-NEXT:                OpStore %param_var_c [[m]]
-// CHECK-NEXT:   [[n:%\d+]] = OpLoad %float %n
-// CHECK-NEXT:                OpStore %param_var_d [[n]]
-// CHECK-NEXT:   [[p:%\d+]] = OpLoad %Pixel %p
-// CHECK-NEXT:                OpStore %param_var_e [[p]]
 
-// CHECK-NEXT: [[ret:%\d+]] = OpFunctionCall %float %fnInOut %param_var_a %param_var_b %param_var_c %param_var_d
-
-// CHECK-NEXT:   [[c:%\d+]] = OpLoad %float %param_var_c
-// CHECK-NEXT:                OpStore %m [[c]]
-// CHECK-NEXT:   [[d:%\d+]] = OpLoad %float %param_var_d
-// CHECK-NEXT:                OpStore %n [[d]]
-// CHECK-NEXT:   [[e:%\d+]] = OpLoad %Pixel %param_var_e
-// CHECK-NEXT:                OpStore %p [[e]]
+// CHECK-NEXT: [[ret:%\d+]] = OpFunctionCall %float %fnInOut %param_var_a %param_var_b %m %n %p
 
 // CHECK-NEXT:                OpReturnValue [[ret]]
     return fnInOut(5., val, m, n, p);
-// CHECK-NEXT: OpFunctionEnd
 }

+ 37 - 0
tools/clang/test/CodeGenSPIRV/fn.param.inout.no-copy.hlsl

@@ -0,0 +1,37 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct S {
+    float4 val;
+};
+
+void foo(
+    out   int      a,
+    inout uint2    b,
+    out   float2x3 c,
+    inout S        d,
+    out   float    e[4]
+) {
+    a = 0;
+    b = 1;
+    c = 2.0;
+    d.val = 3.0;
+    e[0] = 4.0;
+}
+
+void main() {
+    int      a;
+    uint2    b;
+    float2x3 c;
+    S        d;
+    float    e[4];
+
+// CHECK: %a = OpVariable %_ptr_Function_int Function
+// CHECK: %b = OpVariable %_ptr_Function_v2uint Function
+// CHECK: %c = OpVariable %_ptr_Function_mat2v3float Function
+// CHECK: %d = OpVariable %_ptr_Function_S Function
+// CHECK: %e = OpVariable %_ptr_Function__arr_float_uint_4 Function
+
+// CHECK:      OpFunctionCall %void %foo %a %b %c %d %e
+
+    foo(a, b, c, d, e);
+}

+ 1 - 3
tools/clang/test/CodeGenSPIRV/fn.param.inout.vector.hlsl

@@ -18,9 +18,7 @@ float4 main() : C {
 
     float4 val;
 // CHECK:    [[z_ptr:%\d+]] = OpAccessChain %_ptr_Function_float %val %int_2
-// CHECK:          {{%\d+}} = OpFunctionCall %void %bar %param_var_x %param_var_y %param_var_z %param_var_w
-// CHECK-NEXT:   [[x:%\d+]] = OpLoad %v4float %param_var_x
-// CHECK-NEXT:                OpStore %val [[x]]
+// CHECK:          {{%\d+}} = OpFunctionCall %void %bar %val %param_var_y %param_var_z %param_var_w
 // CHECK-NEXT:   [[y:%\d+]] = OpLoad %v3float %param_var_y
 // CHECK-NEXT: [[old:%\d+]] = OpLoad %v4float %val
     // Write to val.zwx:

+ 9 - 1
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -435,13 +435,21 @@ TEST_F(FileTest, ControlFlowConditionalOp) { runFileTest("cf.cond-op.hlsl"); }
 // For functions
 TEST_F(FileTest, FunctionCall) { runFileTest("fn.call.hlsl"); }
 TEST_F(FileTest, FunctionDefaultArg) { runFileTest("fn.default-arg.hlsl"); }
-TEST_F(FileTest, FunctionInOutParam) { runFileTest("fn.param.inout.hlsl"); }
+TEST_F(FileTest, FunctionInOutParam) {
+  // Tests using uniform/in/out/inout annotations on function parameters
+  runFileTest("fn.param.inout.hlsl");
+}
 TEST_F(FileTest, FunctionInOutParamVector) {
   runFileTest("fn.param.inout.vector.hlsl");
 }
 TEST_F(FileTest, FunctionInOutParamDiffStorageClass) {
   runFileTest("fn.param.inout.storage-class.hlsl");
 }
+TEST_F(FileTest, FunctionInOutParamNoNeedToCopy) {
+  // Tests that referencing function scope variables as a whole with out/inout
+  // annotation does not create temporary variables
+  runFileTest("fn.param.inout.no-copy.hlsl");
+}
 TEST_F(FileTest, FunctionFowardDeclaration) {
   runFileTest("fn.foward-declaration.hlsl");
 }