Sfoglia il codice sorgente

Support nested struct type when preprocessArgUsedInCall.

Xiang Li 8 anni fa
parent
commit
f47fc4d40b

+ 25 - 17
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -5366,6 +5366,30 @@ void SROA_Parameter_HLSL::flattenArgument(
 
 }
 
+static bool IsUsedAsCallArg(Value *V) {
+  for (User *U : V->users()) {
+    if (CallInst *CI = dyn_cast<CallInst>(U)) {
+      Function *CalledF = CI->getCalledFunction();
+      HLOpcodeGroup group = GetHLOpcodeGroup(CalledF);
+      // Skip HL operations.
+      if (group != HLOpcodeGroup::NotHL ||
+          group == HLOpcodeGroup::HLExtIntrinsic) {
+        continue;
+      }
+      // Skip llvm intrinsic.
+      if (CalledF->isIntrinsic())
+        continue;
+
+      return true;
+    }
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
+      if (IsUsedAsCallArg(GEP))
+        return true;
+    }
+  }
+  return false;
+}
+
 // For function parameter which used in function call and need to be flattened.
 // Replace with tmp alloca.
 void SROA_Parameter_HLSL::preprocessArgUsedInCall(Function *F) {
@@ -5397,24 +5421,8 @@ void SROA_Parameter_HLSL::preprocessArgUsedInCall(Function *F) {
     if (!Ty->isAggregateType() &&
         Ty->getScalarType() == Ty)
       continue;
-    bool bUsedInCall = false;
-    for (User *U : arg.users()) {
-      if (CallInst *CI = dyn_cast<CallInst>(U)) {
-        Function *CalledF = CI->getCalledFunction();
-        HLOpcodeGroup group = GetHLOpcodeGroup(CalledF);
-        // Skip HL operations.
-        if (group != HLOpcodeGroup::NotHL ||
-            group == HLOpcodeGroup::HLExtIntrinsic) {
-          continue;
-        }
-        // Skip llvm intrinsic.
-        if (CalledF->isIntrinsic())
-          continue;
 
-        bUsedInCall = true;
-        break;
-      }
-    }
+    bool bUsedInCall = IsUsedAsCallArg(&arg);
 
     if (bUsedInCall) {
       // Create tmp.

+ 22 - 0
tools/clang/test/CodeGenHLSL/shader-compat-suite/lib_arg_flatten/lib_arg_flatten4.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -T lib_6_1 %s | FileCheck %s
+
+// Make sure nested struct parameter not replaced as function call arg.
+
+// CHECK: call void @"\01?test_extern@@YAMUFoo@@@Z"(float %{{.*}}, float* nonnull %{{.*}})
+
+struct Foo {
+  float a;
+};
+
+struct Bar {
+  Foo foo;
+  float b;
+};
+
+float test_extern(Foo foo);
+
+float test(Bar b)
+{
+  float x = test_extern(b.foo);
+  return x + b.b;
+}