Browse Source

Support memcpy to self. (#337)

Xiang Li 8 years ago
parent
commit
4af3ac8edd

+ 28 - 17
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -2425,12 +2425,24 @@ void MemcpySplitter::PatchMemCpyWithZeroIdxGEP(Module &M) {
   }
 }
 
-void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
-                                 DxilFieldAnnotation *fieldAnnotation,
-                                 DxilTypeSystem &typeSys) {
+/// Erase \p MI, then erase each of its two pointer operands (operand 0 and
+/// operand 1) that is an instruction left without any users once the memcpy
+/// is gone, such as a bitcast that only fed this memcpy.
+///
+/// \param MI memcpy to remove; it must not be used after this call.
+static void DeleteMemcpy(MemCpyInst *MI) {
   Value *Op0 = MI->getOperand(0);
   Value *Op1 = MI->getOperand(1);
+  // delete memcpy
+  MI->eraseFromParent();
+  if (Instruction *op0 = dyn_cast<Instruction>(Op0)) {
+    if (op0->user_empty())
+      op0->eraseFromParent();
+  }
+  // For a self copy both operands may be the very same instruction. It has
+  // already been examined (and possibly erased) above, so touching it again
+  // would be a use-after-free / double erase.
+  if (Op0 != Op1) {
+    if (Instruction *op1 = dyn_cast<Instruction>(Op1)) {
+      if (op1->user_empty())
+        op1->eraseFromParent();
+    }
+  }
+}
 
+void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
+                                 DxilFieldAnnotation *fieldAnnotation,
+                                 DxilTypeSystem &typeSys) {
   Value *Dest = MI->getRawDest();
   Value *Src = MI->getRawSource();
   // Only remove one level bitcast generated from inline.
@@ -2439,6 +2451,12 @@ void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
   if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
     Src = BC->getOperand(0);
 
+  if (Dest == Src) {
+    // delete self copy.
+    DeleteMemcpy(MI);
+    return;
+  }
+
   IRBuilder<> Builder(MI);
   Type *DestTy = Dest->getType()->getPointerElementType();
   Type *SrcTy = Src->getType()->getPointerElementType();
@@ -2455,15 +2473,7 @@ void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
   SplitCpy(Dest->getType(), Dest, Src, idxList, Builder, typeSys,
            fieldAnnotation);
   // delete memcpy
-  MI->eraseFromParent();
-  if (Instruction *op0 = dyn_cast<Instruction>(Op0)) {
-    if (op0->user_empty())
-      op0->eraseFromParent();
-  }
-  if (Instruction *op1 = dyn_cast<Instruction>(Op1)) {
-    if (op1->user_empty())
-      op1->eraseFromParent();
-  }
+  DeleteMemcpy(MI);
 }
 
 void MemcpySplitter::Split(llvm::Function &F) {
@@ -3484,7 +3494,7 @@ struct PointerStatus {
   /// this global, keep track of what value it is.
   Value *StoredOnceValue;
   /// Memcpy which this ptr is used.
-  std::vector<MemCpyInst *> memcpyList;
+  std::unordered_set<MemCpyInst *> memcpySet;
   /// Memcpy which use this ptr as dest.
   MemCpyInst *StoringMemcpy;
   /// Memcpy which use this ptr as src.
@@ -3540,7 +3550,8 @@ void PointerStatus::analyzePointer(const Value *V, PointerStatus &PS,
       // Do not collect memcpy on struct GEP use.
       // These memcpy will be flattened in next level.
       if (!bStructElt) {
-        PS.memcpyList.emplace_back(const_cast<MemCpyInst *>(MC));
+        MemCpyInst *MI = const_cast<MemCpyInst *>(MC);
+        PS.memcpySet.insert(MI);
         bool bFullCopy = false;
         if (ConstantInt *Length = dyn_cast<ConstantInt>(MC->getLength())) {
           bFullCopy = PS.Size == Length->getLimitedValue();
@@ -3549,7 +3560,7 @@ void PointerStatus::analyzePointer(const Value *V, PointerStatus &PS,
           if (bFullCopy &&
               PS.StoredType == PointerStatus::StoredType::NotStored) {
             PS.StoredType = PointerStatus::StoredType::MemcopyDestOnce;
-            PS.StoringMemcpy = PS.memcpyList.back();
+            PS.StoringMemcpy = MI;
           } else {
             PS.MarkAsStored();
             PS.StoringMemcpy = nullptr;
@@ -3558,7 +3569,7 @@ void PointerStatus::analyzePointer(const Value *V, PointerStatus &PS,
           if (bFullCopy &&
               PS.LoadedType == PointerStatus::LoadedType::NotLoaded) {
             PS.LoadedType = PointerStatus::LoadedType::MemcopySrcOnce;
-            PS.LoadingMemcpy = PS.memcpyList.back();
+            PS.LoadingMemcpy = MI;
           } else {
             PS.MarkAsLoaded();
             PS.LoadingMemcpy = nullptr;
@@ -3784,7 +3795,7 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
     }
   }
 
-  for (MemCpyInst *MC : PS.memcpyList) {
+  for (MemCpyInst *MC : PS.memcpySet) {
     MemcpySplitter::SplitMemCpy(MC, DL, annotation, typeSys);
   }
   return false;

+ 19 - 0
tools/clang/test/CodeGenHLSL/self_copy.hlsl

@@ -0,0 +1,19 @@
+// RUN: %dxc -E main -T ps_6_0 %s
+
+struct N {
+  float n;
+};
+
+struct S {
+  float s;
+  N  n;  // nested struct member, so S is an aggregate containing an aggregate
+};
+
+S s0;  // global instance; copied into the local s1 in main
+
+float4 main(float4 a : A, float4 b:B) : SV_TARGET
+{
+  S s1 = s0;
+  s1 = s1;  // self-assignment under test; presumably lowers to a memcpy with identical source and destination (confirm in IR)
+  return s1.n.n;  // read the nested field so s1 stays live after the self copy
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -606,6 +606,7 @@ public:
   TEST_METHOD(CodeGenSelectObj3)
   TEST_METHOD(CodeGenSelectObj4)
   TEST_METHOD(CodeGenSelectObj5)
+  TEST_METHOD(CodeGenSelfCopy)
   TEST_METHOD(CodeGenSelMat)
   TEST_METHOD(CodeGenShare_Mem_Dbg)
   TEST_METHOD(CodeGenShare_Mem_Phi)
@@ -3179,6 +3180,10 @@ TEST_F(CompilerTest, CodeGenSelectObj5) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\selectObj5.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenSelfCopy) { // runs self_copy.hlsl, which assigns a struct to itself
+  CodeGenTest(L"..\\CodeGenHLSL\\self_copy.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenSelMat) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\selMat.hlsl");
 }