Selaa lähdekoodia

[spirv] Use temporary variable for method call on static variable (#973)

Static variables are in the Private storage class but all methods
are generated to take pointers of the Function storage class. So
to call a method on a static variable, we need first create a
temporary variable initialized with the contents of the static
variable. After the method call, we also need to write the contents
back to the static variable in case there are side effects.
Lei Zhang 7 vuotta sitten
vanhempi
commit
46c0e3cf2f

+ 37 - 7
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1603,7 +1603,11 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
   }
 
   const auto numParams = callee->getNumParams();
+
   bool isNonStaticMemberCall = false;
+  QualType objectType = {};         // Type of the object (if exists)
+  SpirvEvalInfo objectEvalInfo = 0; // EvalInfo for the object (if exists)
+  bool objectNeedsTempVar = false;  // Temporary variable for lvalue object
 
   llvm::SmallVector<uint32_t, 4> params;    // Temporary variables
   llvm::SmallVector<SpirvEvalInfo, 4> args; // Evaluated arguments
@@ -1615,16 +1619,32 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
       // For non-static member calls, evaluate the object and pass it as the
       // first argument.
       const auto *object = memberCall->getImplicitObjectArgument();
-      const auto objectEvalInfo = doExpr(object);
+      object = object->IgnoreParenNoopCasts(astContext);
+
+      objectType = object->getType();
+      objectEvalInfo = doExpr(object);
       uint32_t objectId = objectEvalInfo;
 
       // If not already a variable, we need to create a temporary variable and
       // pass the object pointer to the function. Example:
       // getObject().objectMethod();
-      if (objectEvalInfo.isRValue()) {
-        const auto objType = object->getType();
-        objectId = createTemporaryVar(objType, TypeTranslator::getName(objType),
-                                      objectEvalInfo);
+      bool needsTempVar = objectEvalInfo.isRValue();
+
+      // Try to see if we are calling methods on a global variable, which is put
+      // in the Private storage class. We also need to create temporary variable
+      // for it since the function signature expects all arguments in the
+      // Function storage class.
+      if (!needsTempVar)
+        if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(object))
+          if (const auto *refDecl = declRefExpr->getFoundDecl())
+            if (const auto *varDecl = dyn_cast<VarDecl>(refDecl))
+              needsTempVar = objectNeedsTempVar = varDecl->hasGlobalStorage();
+
+      if (needsTempVar) {
+        objectId =
+            createTemporaryVar(objectType, TypeTranslator::getName(objectType),
+                               // May need to load to use as initializer
+                               loadIfGLValue(object, objectEvalInfo));
       }
 
       args.push_back(objectId);
@@ -1675,6 +1695,15 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
   const uint32_t retVal =
       theBuilder.createFunctionCall(retType, funcId, params);
 
+  // If we created a temporary variable for the object this method is invoked
+  // upon, we need to copy the contents in the temporary variable back to the
+  // original object's variable in case there are side effects.
+  if (objectNeedsTempVar) {
+    const uint32_t typeId = typeTranslator.translateType(objectType);
+    const uint32_t value = theBuilder.createLoad(typeId, params.front());
+    storeValue(objectEvalInfo, value, objectType);
+  }
+
   // Go through all parameters and write those marked as out/inout
   for (uint32_t i = 0; i < numParams; ++i) {
     const auto *param = callee->getParamDecl(i);
@@ -3570,8 +3599,9 @@ SpirvEvalInfo SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {
 }
 
 uint32_t SPIRVEmitter::createTemporaryVar(QualType type, llvm::StringRef name,
-                                          uint32_t init) {
-  const uint32_t varType = typeTranslator.translateType(type);
+                                          const SpirvEvalInfo &init) {
+  const uint32_t varType =
+      typeTranslator.translateType(type, init.getLayoutRule());
   const std::string varName = "temp.var." + name.str();
   const uint32_t varId = theBuilder.addFnVar(varType, varName);
   theBuilder.createStore(varId, init);

+ 1 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -253,7 +253,7 @@ private:
   /// Returns the <result-id> of the variable.
   uint32_t SPIRVEmitter::createTemporaryVar(QualType varType,
                                             llvm::StringRef varName,
-                                            uint32_t initValue);
+                                            const SpirvEvalInfo &initValue);
 
   /// Collects all indices (SPIR-V constant values) from consecutive MemberExprs
   /// or ArraySubscriptExprs or operator[] calls and writes into indices.

+ 20 - 0
tools/clang/test/CodeGenSPIRV/oo.method.on-static-var.hlsl

@@ -0,0 +1,20 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct S {
+    float val;
+
+    float getVal() { return val; }
+};
+
+static S gSVar = {4.2};
+
+float main() : A {
+// CHECK:      %temp_var_S = OpVariable %_ptr_Function_S Function
+
+// CHECK:       [[s:%\d+]] = OpLoad %S %gSVar
+// CHECK-NEXT:               OpStore %temp_var_S [[s]]
+// CHECK-NEXT:    {{%\d+}} = OpFunctionCall %float %S_getVal %temp_var_S
+// CHECK-NEXT:  [[s:%\d+]] = OpLoad %S %temp_var_S
+// CHECK-NEXT:               OpStore %gSVar [[s]]
+    return gSVar.getVal();
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -397,6 +397,9 @@ TEST_F(FileTest, ClassStaticMember) {
 TEST_F(FileTest, StaticMemberInitializer) {
   runFileTest("oo.static.member.init.hlsl");
 }
+TEST_F(FileTest, MethodCallOnStaticVar) {
+  runFileTest("oo.method.on-static-var.hlsl");
+}
 
 // For semantics
 // SV_Position, SV_ClipDistance, and SV_CullDistance are covered in