Prechádzať zdrojové kódy

Support case write to struct input. (#731)

Xiang Li 7 rokov pred
rodič
commit
ef7a891ab2

+ 5 - 2
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -3726,7 +3726,8 @@ static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC) {
     Type* TyV = V->getType()->getPointerElementType();
     Type* TySrc = Src->getType()->getPointerElementType();
     if (TyV == TySrc) {
-      V->replaceAllUsesWith(Src);
+      if (V != Src)
+        V->replaceAllUsesWith(Src);
     } else {
       DXASSERT((TyV->isArrayTy() && TySrc->isArrayTy()) &&
                (TyV->getArrayNumElements() == 0 ||
@@ -3764,7 +3765,9 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
   const bool bStructElt = false;
   PointerStatus::analyzePointer(V, PS, typeSys, bStructElt);
   if (bAllowReplace && !PS.HasMultipleAccessingFunctions) {
-    if (PS.StoredType == PointerStatus::StoredType::MemcopyDestOnce) {
+    if (PS.StoredType == PointerStatus::StoredType::MemcopyDestOnce &&
+        // Skip argument for input argument has input value, it is not dest once anymore.
+        !isa<Argument>(V)) {
       // Replace with src of memcpy.
       MemCpyInst *MC = PS.StoringMemcpy;
       if (MC->getSourceAddressSpace() == MC->getDestAddressSpace()) {

+ 18 - 0
tools/clang/test/CodeGenHLSL/writeToInput4.hlsl

@@ -0,0 +1,18 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Make sure the input is used.
+// CHECK: call float @dx.op.loadInput.f32
+
+struct Input {
+    float a : A;
+    float b : B;
+};
+
+Input ci;
+
+float4 main(Input i) : SV_Target {
+   float c = i.a + i.b;
+   i = ci;
+   c += i.a * i.b;
+   return c;
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -823,6 +823,7 @@ public:
   TEST_METHOD(CodeGenWriteToInput)
   TEST_METHOD(CodeGenWriteToInput2)
   TEST_METHOD(CodeGenWriteToInput3)
+  TEST_METHOD(CodeGenWriteToInput4)
 
   TEST_METHOD(CodeGenAttributes_Mod)
   TEST_METHOD(CodeGenConst_Exprb_Mod)
@@ -4399,6 +4400,10 @@ TEST_F(CompilerTest, CodeGenWriteToInput3) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\writeToInput3.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenWriteToInput4) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\writeToInput4.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenAttributes_Mod) {
   CodeGenTest(L"..\\CodeGenHLSL\\attributes_Mod.hlsl");
 }