浏览代码

[spirv] Member fn call use Function storage class. (#1226)

Fixes https://github.com/Microsoft/DirectXShaderCompiler/issues/1222
Ehsan 7 年之前
父节点
当前提交
483901edba
共有 2 个文件被更改,包括 31 次插入19 次删除
  1. 10 16
      tools/clang/lib/SPIRV/SPIRVEmitter.cpp
  2. 21 3
      tools/clang/test/CodeGenSPIRV/oo.struct.method.hlsl

+ 10 - 16
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1957,7 +1957,7 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
   bool isNonStaticMemberCall = false;
   QualType objectType = {};         // Type of the object (if exists)
   SpirvEvalInfo objectEvalInfo = 0; // EvalInfo for the object (if exists)
-  bool objectNeedsTempVar = false;  // Temporary variable for lvalue object
+  bool needsTempVar = false;        // Whether we need temporary variable.
 
   llvm::SmallVector<uint32_t, 4> params;    // Temporary variables
   llvm::SmallVector<SpirvEvalInfo, 4> args; // Evaluated arguments
@@ -1982,17 +1982,11 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
       // If not already a variable, we need to create a temporary variable and
       // pass the object pointer to the function. Example:
       // getObject().objectMethod();
-      bool needsTempVar = objectEvalInfo.isRValue();
-
-      // Try to see if we are calling methods on a global variable, which is put
-      // in the Private storage class. We also need to create temporary variable
-      // for it since the function signature expects all arguments in the
-      // Function storage class.
-      if (!needsTempVar)
-        if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(object))
-          if (const auto *refDecl = declRefExpr->getFoundDecl())
-            if (const auto *varDecl = dyn_cast<VarDecl>(refDecl))
-              needsTempVar = objectNeedsTempVar = varDecl->hasGlobalStorage();
+      // Also, any parameter passed to the member function must be of Function
+      // storage class.
+      needsTempVar =
+          objectEvalInfo.isRValue() ||
+          objectEvalInfo.getStorageClass() != spv::StorageClass::Function;
 
       if (needsTempVar) {
         objectId =
@@ -2048,10 +2042,10 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
   const uint32_t retVal =
       theBuilder.createFunctionCall(retType, funcId, params);
 
-  // If we created a temporary variable for the object this method is invoked
-  // upon, we need to copy the contents in the temporary variable back to the
-  // original object's variable in case there are side effects.
-  if (objectNeedsTempVar) {
+  // If we created a temporary variable for the lvalue object this method is
+  // invoked upon, we need to copy the contents in the temporary variable back
+  // to the original object's variable in case there are side effects.
+  if (needsTempVar && !objectEvalInfo.isRValue()) {
     const uint32_t typeId = typeTranslator.translateType(objectType);
     const uint32_t value = theBuilder.createLoad(typeId, params.front());
     storeValue(objectEvalInfo, value, objectType);

+ 21 - 3
tools/clang/test/CodeGenSPIRV/oo.struct.method.hlsl

@@ -59,6 +59,12 @@ T foo() {
   return t;
 }
 
+struct R {
+  int a;
+  void incr() { ++a; }
+};
+RWStructuredBuffer<R> rwsb;
+
 // CHECK:     [[ft_f32:%\d+]] = OpTypeFunction %float
 // CHECK:       [[ft_S:%\d+]] = OpTypeFunction %float %_ptr_Function_S
 // CHECK:   [[ft_S_f32:%\d+]] = OpTypeFunction %float %_ptr_Function_S %_ptr_Function_float
@@ -98,6 +104,18 @@ float main() : A {
 // CHECK-NEXT:        {{%\d+}} = OpFunctionCall %float %S_fn_ref %temp_var_S
   float f2 = foo().get_S().fn_ref();
 
+// CHECK:         [[uniformPtr:%\d+]] = OpAccessChain %_ptr_Uniform_R %rwsb %int_0 %uint_0
+// CHECK-NEXT:   [[originalObj:%\d+]] = OpLoad %R [[uniformPtr]]
+// CHECK-NEXT:        [[member:%\d+]] = OpCompositeExtract %int [[originalObj]] 0
+// CHECK-NEXT:       [[tempVar:%\d+]] = OpCompositeConstruct %R_0 [[member]]
+// CHECK-NEXT:                          OpStore %temp_var_R [[tempVar]]
+// CHECK-NEXT:                          OpFunctionCall %void %R_incr %temp_var_R
+// CHECK-NEXT:       [[tempVar:%\d+]] = OpLoad %R_0 %temp_var_R
+// CHECK-NEXT: [[tempVarMember:%\d+]] = OpCompositeExtract %int [[tempVar]] 0
+// CHECK-NEXT:          [[newR:%\d+]] = OpCompositeConstruct %R [[tempVarMember]]
+// CHECK-NEXT:                          OpStore [[uniformPtr]] [[newR]]
+  rwsb[0].incr();
+
   return f1;
 // CHECK:                     OpFunctionEnd
 }
@@ -151,8 +169,8 @@ float main() : A {
 // CHECK:                      OpFunctionEnd
 
 // CHECK:        %S_fn_param = OpFunction %float None [[ft_S_f32]]
-// CHECK-NEXT: %param_this_4 = OpFunctionParameter %_ptr_Function_S
+// CHECK-NEXT: %param_this_5 = OpFunctionParameter %_ptr_Function_S
 // CHECK-NEXT:          %c_0 = OpFunctionParameter %_ptr_Function_float
-// CHECK-NEXT:   %bb_entry_8 = OpLabel
-// CHECK:           {{%\d+}} = OpAccessChain %_ptr_Function_float %param_this_4 %int_0
+// CHECK-NEXT:                 OpLabel
+// CHECK:                      OpAccessChain %_ptr_Function_float %param_this_5 %int_0
 // CHECK:                      OpFunctionEnd