Browse Source

Correct DomTree usage in memcpy lowering (#2839)

* Correct DomTree usage in memcpy lowering

By copying the code meant for memcpy used on global variables, I forced
recalculation of the DominatorTree for each and every memcpy. This
was pretty slow and entirely unnecessary. The global variables are
invoked from SROA_HLSL_Parameters, which is a module pass that doesn't
really know which function its dealing with until it addresses the
actual memcpy. SROA_HLSL already depends on the domination tree analysis
pass which allows the pass to draw from that for the dominator tree.

In addition, copying from the original code made use of a Post Dominator
Tree. This was a mistake. What we really want to know is if the source
dominates the uses of the destination. In practice, rarely would
anything be found to post dominate unless they shared a block. In the
contrived case where a memcpy post dominated its destination, the way
dxilgen dealt with cbuffers made it irrelevant because they were all
replaced with the result of its initialization result in the entry.

Finally, as Xiang suggested, this now uses the source to determine
domination.

The result of these is faster compiles and faster code produced. No
incorrect behavior should have resulted from the previous code. So this
doesn't correct any.

Incidental removal of unused method in SROA_HLSL

* Avoid more problem memcpy replaces

The previous location for memcpy replace too narrow. It only handled one
way that memcpy replacement might be triggered. What's more, it was done
regardless of other things, potentially quicker to check that might
prevent the replacement. By moving the check into ReplaceMemcpy, it is
checked every time it needs to be and never when it doesn't.

Additionally, ReplaceMemcpy was called on memcpys that were created as
part of a memcpy split. This isn't a problem if they get replaced by
loads and stores, but if all uses of the dest are replaced by the src,
the memcpy is essentially removed. sub-memcpy params are only used
directly by their memcpy. So this effectively removes the sub memcpy
entirely. By detecting when the only users of the memcpy parameters are
that memcpy and refusing to replace it entirely, we evade this however
it arises.

This eliminates the earlier fix that disabled subelement memcpy when the
domination check failed. The effect is largely the same in many cases,
but makes the check more directly.
Greg Roth 5 years ago
parent
commit
9840380ef8

+ 72 - 115
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -84,27 +84,29 @@ public:
                                   IRBuilder<> &Builder, bool bFlatVector,
                                   IRBuilder<> &Builder, bool bFlatVector,
                                   bool hasPrecise, DxilTypeSystem &typeSys,
                                   bool hasPrecise, DxilTypeSystem &typeSys,
                                   const DataLayout &DL,
                                   const DataLayout &DL,
-                                  SmallVector<Value *, 32> &DeadInsts);
+                                  SmallVector<Value *, 32> &DeadInsts,
+                                  DominatorTree *DT);
 
 
   static bool DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &Elts,
   static bool DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &Elts,
                                   IRBuilder<> &Builder, bool bFlatVector,
                                   IRBuilder<> &Builder, bool bFlatVector,
                                   bool hasPrecise, DxilTypeSystem &typeSys,
                                   bool hasPrecise, DxilTypeSystem &typeSys,
                                   const DataLayout &DL,
                                   const DataLayout &DL,
-                                  SmallVector<Value *, 32> &DeadInsts);
+                                  SmallVector<Value *, 32> &DeadInsts,
+                                  DominatorTree *DT);
   static unsigned GetEltAlign(unsigned ValueAlign, const DataLayout &DL,
   static unsigned GetEltAlign(unsigned ValueAlign, const DataLayout &DL,
                               Type *EltTy, unsigned Offset);
                               Type *EltTy, unsigned Offset);
   // Lower memcpy related to V.
   // Lower memcpy related to V.
   static bool LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
   static bool LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
                           DxilTypeSystem &typeSys, const DataLayout &DL,
                           DxilTypeSystem &typeSys, const DataLayout &DL,
-                          bool bAllowReplace);
+                          DominatorTree *DT, bool bAllowReplace);
   static void MarkEmptyStructUsers(Value *V,
   static void MarkEmptyStructUsers(Value *V,
                                    SmallVector<Value *, 32> &DeadInsts);
                                    SmallVector<Value *, 32> &DeadInsts);
   static bool IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys);
   static bool IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys);
 private:
 private:
   SROA_Helper(Value *V, ArrayRef<Value *> Elts,
   SROA_Helper(Value *V, ArrayRef<Value *> Elts,
               SmallVector<Value *, 32> &DeadInsts, DxilTypeSystem &ts,
               SmallVector<Value *, 32> &DeadInsts, DxilTypeSystem &ts,
-              const DataLayout &dl)
-      : OldVal(V), NewElts(Elts), DeadInsts(DeadInsts), typeSys(ts), DL(dl) {}
+              const DataLayout &dl, DominatorTree *dt)
+    : OldVal(V), NewElts(Elts), DeadInsts(DeadInsts), typeSys(ts), DL(dl), DT(dt) {}
   void RewriteForScalarRepl(Value *V, IRBuilder<> &Builder);
   void RewriteForScalarRepl(Value *V, IRBuilder<> &Builder);
 
 
 private:
 private:
@@ -115,6 +117,7 @@ private:
   SmallVector<Value *, 32> &DeadInsts;
   SmallVector<Value *, 32> &DeadInsts;
   DxilTypeSystem  &typeSys;
   DxilTypeSystem  &typeSys;
   const DataLayout &DL;
   const DataLayout &DL;
+  DominatorTree *DT;
 
 
   void RewriteForConstExpr(ConstantExpr *user, IRBuilder<> &Builder);
   void RewriteForConstExpr(ConstantExpr *user, IRBuilder<> &Builder);
   void RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder);
   void RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder);
@@ -145,7 +148,6 @@ struct SROA_HLSL : public FunctionPass {
   bool runOnFunction(Function &F) override;
   bool runOnFunction(Function &F) override;
 
 
   bool performScalarRepl(Function &F, DxilTypeSystem &typeSys);
   bool performScalarRepl(Function &F, DxilTypeSystem &typeSys);
-  bool performPromotion(Function &F);
   bool markPrecise(Function &F);
   bool markPrecise(Function &F);
 
 
 private:
 private:
@@ -230,7 +232,7 @@ private:
   bool ShouldAttemptScalarRepl(AllocaInst *AI);
   bool ShouldAttemptScalarRepl(AllocaInst *AI);
 };
 };
 
 
-// SROA_DT - SROA that uses DominatorTree.
+// SROA_DT_HLSL - SROA that uses DominatorTree.
 struct SROA_DT_HLSL : public SROA_HLSL {
 struct SROA_DT_HLSL : public SROA_HLSL {
   static char ID;
   static char ID;
 
 
@@ -710,58 +712,6 @@ static bool tryToMakeAllocaBePromotable(AllocaInst *AI, const DataLayout &DL) {
   return true;
   return true;
 }
 }
 
 
-bool SROA_HLSL::performPromotion(Function &F) {
-  std::vector<AllocaInst *> Allocas;
-  const DataLayout &DL = F.getParent()->getDataLayout();
-  DominatorTree *DT = nullptr;
-  if (HasDomTree)
-    DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
-  AssumptionCache &AC =
-      getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
-
-  BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function
-  DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false);
-  bool Changed = false;
-  SmallVector<Instruction *, 64> Insts;
-  while (1) {
-    Allocas.clear();
-
-    // Find allocas that are safe to promote, by looking at all instructions in
-    // the entry node
-    for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I)
-      if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { // Is it an alloca?
-        DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(AI);
-        // Skip alloca has debug info when not promote.
-        if (DDI && !RunPromotion) {
-          continue;
-        }
-        if (tryToMakeAllocaBePromotable(AI, DL))
-          Allocas.push_back(AI);
-      }
-    if (Allocas.empty())
-      break;
-
-    if (HasDomTree)
-      PromoteMemToReg(Allocas, *DT, nullptr, &AC);
-    else {
-      SSAUpdater SSA;
-      for (unsigned i = 0, e = Allocas.size(); i != e; ++i) {
-        AllocaInst *AI = Allocas[i];
-
-        // Build list of instructions to promote.
-        for (User *U : AI->users())
-          Insts.push_back(cast<Instruction>(U));
-        AllocaPromoter(Insts, SSA, &DIB).run(AI, Insts);
-        Insts.clear();
-      }
-    }
-    NumPromoted += Allocas.size();
-    Changed = true;
-  }
-
-  return Changed;
-}
-
 /// ShouldAttemptScalarRepl - Decide if an alloca is a good candidate for
 /// ShouldAttemptScalarRepl - Decide if an alloca is a good candidate for
 /// SROA.  It must be a struct or array type with a small number of elements.
 /// SROA.  It must be a struct or array type with a small number of elements.
 bool SROA_HLSL::ShouldAttemptScalarRepl(AllocaInst *AI) {
 bool SROA_HLSL::ShouldAttemptScalarRepl(AllocaInst *AI) {
@@ -1146,8 +1096,9 @@ bool SROA_HLSL::performScalarRepl(Function &F, DxilTypeSystem &typeSys) {
       continue;
       continue;
     }
     }
     const bool bAllowReplace = true;
     const bool bAllowReplace = true;
+    DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
     if (SROA_Helper::LowerMemcpy(AI, /*annotation*/ nullptr, typeSys, DL,
     if (SROA_Helper::LowerMemcpy(AI, /*annotation*/ nullptr, typeSys, DL,
-                                 bAllowReplace)) {
+                                 DT, bAllowReplace)) {
       Changed = true;
       Changed = true;
       continue;
       continue;
     }
     }
@@ -1194,9 +1145,10 @@ bool SROA_HLSL::performScalarRepl(Function &F, DxilTypeSystem &typeSys) {
 
 
       Type *BrokenUpTy = nullptr;
       Type *BrokenUpTy = nullptr;
       uint64_t NumInstances = 1;
       uint64_t NumInstances = 1;
+      DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
       bool SROAed = SROA_Helper::DoScalarReplacement(
       bool SROAed = SROA_Helper::DoScalarReplacement(
         AI, Elts, BrokenUpTy, NumInstances, Builder,
         AI, Elts, BrokenUpTy, NumInstances, Builder,
-        /*bFlatVector*/ true, hasPrecise, typeSys, DL, DeadInsts);
+        /*bFlatVector*/ true, hasPrecise, typeSys, DL, DeadInsts, DT);
 
 
       if (SROAed) {
       if (SROAed) {
         Type *Ty = AI->getAllocatedType();
         Type *Ty = AI->getAllocatedType();
@@ -2269,8 +2221,8 @@ void SROA_Helper::RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder) {
         NewGEPs.emplace_back(NewGEP);
         NewGEPs.emplace_back(NewGEP);
       }
       }
       const bool bAllowReplace = isa<AllocaInst>(OldVal);
       const bool bAllowReplace = isa<AllocaInst>(OldVal);
-      if (!SROA_Helper::LowerMemcpy(GEP, /*annoation*/ nullptr, typeSys, DL, bAllowReplace)) {
-        SROA_Helper helper(GEP, NewGEPs, DeadInsts, typeSys, DL);
+      if (!SROA_Helper::LowerMemcpy(GEP, /*annoation*/ nullptr, typeSys, DL, DT, bAllowReplace)) {
+        SROA_Helper helper(GEP, NewGEPs, DeadInsts, typeSys, DL, DT);
         helper.RewriteForScalarRepl(GEP, Builder);
         helper.RewriteForScalarRepl(GEP, Builder);
         for (Value *NewGEP : NewGEPs) {
         for (Value *NewGEP : NewGEPs) {
           if (NewGEP->user_empty() && isa<Instruction>(NewGEP)) {
           if (NewGEP->user_empty() && isa<Instruction>(NewGEP)) {
@@ -2924,7 +2876,7 @@ void SROA_Helper::RewriteForAddrSpaceCast(Value *CE,
                          CE->getType()->getPointerAddressSpace()));
                          CE->getType()->getPointerAddressSpace()));
     NewCasts.emplace_back(NewCast);
     NewCasts.emplace_back(NewCast);
   }
   }
-  SROA_Helper helper(CE, NewCasts, DeadInsts, typeSys, DL);
+  SROA_Helper helper(CE, NewCasts, DeadInsts, typeSys, DL, DT);
   helper.RewriteForScalarRepl(CE, Builder);
   helper.RewriteForScalarRepl(CE, Builder);
 
 
   // Remove the use so that the caller can keep iterating over its other users
   // Remove the use so that the caller can keep iterating over its other users
@@ -3027,7 +2979,8 @@ bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
                                       IRBuilder<> &Builder, bool bFlatVector,
                                       IRBuilder<> &Builder, bool bFlatVector,
                                       bool hasPrecise, DxilTypeSystem &typeSys,
                                       bool hasPrecise, DxilTypeSystem &typeSys,
                                       const DataLayout &DL,
                                       const DataLayout &DL,
-                                      SmallVector<Value *, 32> &DeadInsts) {
+                                      SmallVector<Value *, 32> &DeadInsts,
+                                      DominatorTree *DT) {
   DEBUG(dbgs() << "Found inst to SROA: " << *V << '\n');
   DEBUG(dbgs() << "Found inst to SROA: " << *V << '\n');
   Type *Ty = V->getType();
   Type *Ty = V->getType();
   // Skip none pointer types.
   // Skip none pointer types.
@@ -3157,7 +3110,7 @@ bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
   
   
   // Now that we have created the new alloca instructions, rewrite all the
   // Now that we have created the new alloca instructions, rewrite all the
   // uses of the old alloca.
   // uses of the old alloca.
-  SROA_Helper helper(V, Elts, DeadInsts, typeSys, DL);
+  SROA_Helper helper(V, Elts, DeadInsts, typeSys, DL, DT);
   helper.RewriteForScalarRepl(V, Builder);
   helper.RewriteForScalarRepl(V, Builder);
 
 
   return true;
   return true;
@@ -3220,7 +3173,8 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
                                       IRBuilder<> &Builder, bool bFlatVector,
                                       IRBuilder<> &Builder, bool bFlatVector,
                                       bool hasPrecise, DxilTypeSystem &typeSys,
                                       bool hasPrecise, DxilTypeSystem &typeSys,
                                       const DataLayout &DL,
                                       const DataLayout &DL,
-                                      SmallVector<Value *, 32> &DeadInsts) {
+                                      SmallVector<Value *, 32> &DeadInsts,
+                                      DominatorTree *DT) {
   DEBUG(dbgs() << "Found inst to SROA: " << *GV << '\n');
   DEBUG(dbgs() << "Found inst to SROA: " << *GV << '\n');
   Type *Ty = GV->getType();
   Type *Ty = GV->getType();
   // Skip none pointer types.
   // Skip none pointer types.
@@ -3369,7 +3323,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
 
 
   // Now that we have created the new alloca instructions, rewrite all the
   // Now that we have created the new alloca instructions, rewrite all the
   // uses of the old alloca.
   // uses of the old alloca.
-  SROA_Helper helper(GV, Elts, DeadInsts, typeSys, DL);
+  SROA_Helper helper(GV, Elts, DeadInsts, typeSys, DL, DT);
   helper.RewriteForScalarRepl(GV, Builder);
   helper.RewriteForScalarRepl(GV, Builder);
 
 
   return true;
   return true;
@@ -3471,9 +3425,25 @@ static void CopyElementsOfStructsWithIdenticalLayout(
   }
   }
 }
 }
 
 
-static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC,
-                          DxilFieldAnnotation *annotation,
-                          DxilTypeSystem &typeSys, const DataLayout &DL) {
+static bool DominateAllUsers(Instruction *I, Value *V, DominatorTree *DT);
+
+
+static bool ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC,
+                          DxilFieldAnnotation *annotation, DxilTypeSystem &typeSys,
+                          const DataLayout &DL, DominatorTree *DT) {
+  // If the only user of the src and dst is the memcpy,
+  // this memcpy was probably produced by splitting another.
+  // Regardless, the goal here is to replace, not remove the memcpy
+  // we won't have enough information to determine if we can do that before mem2reg
+  if (V != Src && V->hasOneUse() && Src->hasOneUse())
+    return false;
+
+  // If the memcpy doesn't dominate all its users,
+  // full replacement isn't possible without complicated PHI insertion
+  // This will likely replace with ld/st which will be replaced in mem2reg
+  if (Instruction *SrcI = dyn_cast<Instruction>(Src))
+    if (!DominateAllUsers(SrcI, V, DT))
+      return false;
   Type *TyV = V->getType()->getPointerElementType();
   Type *TyV = V->getType()->getPointerElementType();
   Type *TySrc = Src->getType()->getPointerElementType();
   Type *TySrc = Src->getType()->getPointerElementType();
   if (Constant *C = dyn_cast<Constant>(V)) {
   if (Constant *C = dyn_cast<Constant>(V)) {
@@ -3499,7 +3469,7 @@ static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC,
       Value* SrcVal = MC->getRawSource();
       Value* SrcVal = MC->getRawSource();
       if (!isa<BitCastInst>(SrcVal) || !isa<BitCastInst>(DestVal)) {
       if (!isa<BitCastInst>(SrcVal) || !isa<BitCastInst>(DestVal)) {
         DXASSERT(0, "Encountered unexpected instruction sequence");
         DXASSERT(0, "Encountered unexpected instruction sequence");
-        return;
+        return false;
       }
       }
 
 
       BitCastInst *DestBCI = cast<BitCastInst>(DestVal);
       BitCastInst *DestBCI = cast<BitCastInst>(DestVal);
@@ -3514,7 +3484,7 @@ static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC,
         unsigned MemcpySize = cast<ConstantInt>(MC->getLength())->getZExtValue();
         unsigned MemcpySize = cast<ConstantInt>(MC->getLength())->getZExtValue();
         if (SrcSize != MemcpySize) {
         if (SrcSize != MemcpySize) {
           DXASSERT(0, "Cannot handle partial memcpy");
           DXASSERT(0, "Cannot handle partial memcpy");
-          return;
+          return false;
         }
         }
 
 
         if (DestBCI->hasOneUse() && SrcBCI->hasOneUse()) {
         if (DestBCI->hasOneUse() && SrcBCI->hasOneUse()) {
@@ -3532,13 +3502,13 @@ static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC,
           Value *SrcPtr = SrcBCI->getOperand(0);
           Value *SrcPtr = SrcBCI->getOperand(0);
           if (isa<GEPOperator>(DstPtr) || isa<GEPOperator>(SrcPtr)) {
           if (isa<GEPOperator>(DstPtr) || isa<GEPOperator>(SrcPtr)) {
             MemcpySplitter::SplitMemCpy(MC, DL, annotation, typeSys);
             MemcpySplitter::SplitMemCpy(MC, DL, annotation, typeSys);
-            return;
+            return true;
           } else {
           } else {
             DstPtr->replaceAllUsesWith(SrcPtr);
             DstPtr->replaceAllUsesWith(SrcPtr);
           }
           }
         } else {
         } else {
           DXASSERT(0, "Can't handle structs of different layouts");
           DXASSERT(0, "Can't handle structs of different layouts");
-          return;
+          return false;
         }
         }
       }
       }
     } else {
     } else {
@@ -3567,6 +3537,8 @@ static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC,
     if (I->user_empty())
     if (I->user_empty())
       I->eraseFromParent();
       I->eraseFromParent();
   }
   }
+
+  return true;
 }
 }
 
 
 static bool ReplaceUseOfZeroInitEntry(Instruction *I, Value *V) {
 static bool ReplaceUseOfZeroInitEntry(Instruction *I, Value *V) {
@@ -3653,22 +3625,19 @@ static bool ReplaceUseOfZeroInitBeforeDef(Instruction *I, GlobalVariable *GV) {
   }
   }
 }
 }
 
 
-
-static bool DominateAllUsersPostDom(Instruction *I, Value *V,
-                                    PostDominatorTree &PDT) {
+static bool DominateAllUsersDom(Instruction *I, Value *V, DominatorTree *DT) {
   BasicBlock *BB = I->getParent();
   BasicBlock *BB = I->getParent();
-  Function *F = I->getParent()->getParent();
   for (auto U = V->user_begin(); U != V->user_end(); ) {
   for (auto U = V->user_begin(); U != V->user_end(); ) {
     Instruction *UI = dyn_cast<Instruction>(*(U++));
     Instruction *UI = dyn_cast<Instruction>(*(U++));
     if (!UI)
     if (!UI)
       continue;
       continue;
-    assert (UI->getParent()->getParent() == F);
+    assert (UI->getParent()->getParent() == I->getParent()->getParent());
 
 
-    if (!PDT.dominates(BB, UI->getParent()))
+    if (!DT->dominates(BB, UI->getParent()))
       return false;
       return false;
 
 
     if (isa<GetElementPtrInst>(UI) || isa<BitCastInst>(UI)) {
     if (isa<GetElementPtrInst>(UI) || isa<BitCastInst>(UI)) {
-      if (!DominateAllUsersPostDom(I, UI, PDT))
+      if (!DominateAllUsersDom(I, UI, DT))
         return false;
         return false;
     }
     }
   }
   }
@@ -3676,23 +3645,19 @@ static bool DominateAllUsersPostDom(Instruction *I, Value *V,
 }
 }
 
 
 // Determine if `I` dominates all the users of `V`
 // Determine if `I` dominates all the users of `V`
-static bool DominateAllUsers(Instruction *I, Value *V) {
+static bool DominateAllUsers(Instruction *I, Value *V, DominatorTree *DT) {
   Function *F = I->getParent()->getParent();
   Function *F = I->getParent()->getParent();
 
 
   // The Entry Block dominates everything, trivially true
   // The Entry Block dominates everything, trivially true
   if (&F->getEntryBlock() == I->getParent())
   if (&F->getEntryBlock() == I->getParent())
     return true;
     return true;
 
 
-  // Post dominator tree.
-  PostDominatorTree PDT;
-  PDT.runOnFunction(*F);
-  return DominateAllUsersPostDom(I, V, PDT);
+  return DominateAllUsersDom(I, V, DT);
 }
 }
 
 
-
 bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
 bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
                               DxilTypeSystem &typeSys, const DataLayout &DL,
                               DxilTypeSystem &typeSys, const DataLayout &DL,
-                              bool bAllowReplace) {
+                              DominatorTree *DT, bool bAllowReplace) {
   Type *Ty = V->getType();
   Type *Ty = V->getType();
   if (!Ty->isPointerTy()) {
   if (!Ty->isPointerTy()) {
     return false;
     return false;
@@ -3703,7 +3668,6 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
   unsigned size = DL.getTypeAllocSize(Ty->getPointerElementType());
   unsigned size = DL.getTypeAllocSize(Ty->getPointerElementType());
   hlutil::PointerStatus PS(V, size, /*bLdStOnly*/ false);
   hlutil::PointerStatus PS(V, size, /*bLdStOnly*/ false);
   const bool bStructElt = false;
   const bool bStructElt = false;
-  bool bEltMemcpy = true;
   PS.analyze(typeSys, bStructElt);
   PS.analyze(typeSys, bStructElt);
 
 
   if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
   if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
@@ -3730,17 +3694,6 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
         PS.storedType = hlutil::PointerStatus::StoredType::Stored;
         PS.storedType = hlutil::PointerStatus::StoredType::Stored;
       }
       }
     }
     }
-  } else if (PS.storedType ==
-             hlutil::PointerStatus::StoredType::MemcopyDestOnce) {
-    // As above, it the memcpy doesn't dominate all its users,
-    // full replacement isn't possible without complicated PHI insertion
-    // This will likely replace with ld/st which will be replaced in mem2reg
-    Instruction *Memcpy = PS.StoringMemcpy;
-    if (!DominateAllUsers(Memcpy, V)) {
-      PS.storedType = hlutil::PointerStatus::StoredType::Stored;
-      // Replacing a memcpy with a memcpy with the same signature will just bring us back here
-      bEltMemcpy = false;
-    }
   }
   }
 
 
   if (bAllowReplace && !PS.HasMultipleAccessingFunctions) {
   if (bAllowReplace && !PS.HasMultipleAccessingFunctions) {
@@ -3770,8 +3723,8 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
                   static_cast<HLSubscriptOpcode>(hlsl::GetHLOpcode(PtrCI));
                   static_cast<HLSubscriptOpcode>(hlsl::GetHLOpcode(PtrCI));
               if (opcode == HLSubscriptOpcode::CBufferSubscript) {
               if (opcode == HLSubscriptOpcode::CBufferSubscript) {
                 // Ptr from CBuffer is safe.
                 // Ptr from CBuffer is safe.
-                ReplaceMemcpy(V, Src, MC, annotation, typeSys, DL);
-                return true;
+                if (ReplaceMemcpy(V, Src, MC, annotation, typeSys, DL, DT))
+                  return true;
               }
               }
             }
             }
           }
           }
@@ -3782,8 +3735,8 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
           hlutil::PointerStatus SrcPS(Src, size, /*bLdStOnly*/ false);
           hlutil::PointerStatus SrcPS(Src, size, /*bLdStOnly*/ false);
           SrcPS.analyze(typeSys, bStructElt);
           SrcPS.analyze(typeSys, bStructElt);
           if (SrcPS.storedType != hlutil::PointerStatus::StoredType::Stored) {
           if (SrcPS.storedType != hlutil::PointerStatus::StoredType::Stored) {
-            ReplaceMemcpy(V, Src, MC, annotation, typeSys, DL);
-            return true;
+            if (ReplaceMemcpy(V, Src, MC, annotation, typeSys, DL, DT))
+              return true;
           }
           }
         }
         }
       }
       }
@@ -3806,10 +3759,11 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
           hlutil::PointerStatus DestPS(Dest, size, /*bLdStOnly*/ false);
           hlutil::PointerStatus DestPS(Dest, size, /*bLdStOnly*/ false);
           DestPS.analyze(typeSys, bStructElt);
           DestPS.analyze(typeSys, bStructElt);
           if (DestPS.storedType != hlutil::PointerStatus::StoredType::Stored) {
           if (DestPS.storedType != hlutil::PointerStatus::StoredType::Stored) {
-            ReplaceMemcpy(Dest, V, MC, annotation, typeSys, DL);
-            // V still need to be flatten.
-            // Lower memcpy come from Dest.
-            return LowerMemcpy(V, annotation, typeSys, DL, bAllowReplace);
+            if (ReplaceMemcpy(Dest, V, MC, annotation, typeSys, DL, DT)) {
+              // V still needs to be flattened.
+              // Lower memcpy come from Dest.
+              return LowerMemcpy(V, annotation, typeSys, DL, DT, bAllowReplace);
+            }
           }
           }
         }
         }
       }
       }
@@ -3817,7 +3771,7 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
   }
   }
 
 
   for (MemCpyInst *MC : PS.memcpySet) {
   for (MemCpyInst *MC : PS.memcpySet) {
-    MemcpySplitter::SplitMemCpy(MC, DL, annotation, typeSys, bEltMemcpy);
+    MemcpySplitter::SplitMemCpy(MC, DL, annotation, typeSys);
   }
   }
   return false;
   return false;
 }
 }
@@ -4223,10 +4177,10 @@ void SROA_Parameter_HLSL::flattenGlobal(GlobalVariable *GV) {
   while (!WorkList.empty()) {
   while (!WorkList.empty()) {
     GlobalVariable *EltGV = cast<GlobalVariable>(WorkList.front());
     GlobalVariable *EltGV = cast<GlobalVariable>(WorkList.front());
     WorkList.pop_front();
     WorkList.pop_front();
-
     const bool bAllowReplace = true;
     const bool bAllowReplace = true;
+    // Globals don't need DomTree here because they take another path
     if (SROA_Helper::LowerMemcpy(EltGV, /*annoation*/ nullptr, dxilTypeSys, DL,
     if (SROA_Helper::LowerMemcpy(EltGV, /*annoation*/ nullptr, dxilTypeSys, DL,
-                                 bAllowReplace)) {
+                                 nullptr /*DT */, bAllowReplace)) {
       continue;
       continue;
     }
     }
 
 
@@ -4250,10 +4204,11 @@ void SROA_Parameter_HLSL::flattenGlobal(GlobalVariable *GV) {
       }
       }
       EltGV = NewEltGV;
       EltGV = NewEltGV;
     } else {
     } else {
+      // Globals don't need DomTree
       SROAed = SROA_Helper::DoScalarReplacement(
       SROAed = SROA_Helper::DoScalarReplacement(
           EltGV, Elts, Builder, bFlatVector,
           EltGV, Elts, Builder, bFlatVector,
           // TODO: set precise.
           // TODO: set precise.
-          /*hasPrecise*/ false, dxilTypeSys, DL, DeadInsts);
+          /*hasPrecise*/ false, dxilTypeSys, DL, DeadInsts, /*DT*/ nullptr);
     }
     }
 
 
     if (SROAed) {
     if (SROAed) {
@@ -5049,8 +5004,9 @@ void SROA_Parameter_HLSL::flattenArgument(
     // We can never replace memcpy for arguments because they have an implicit
     // We can never replace memcpy for arguments because they have an implicit
     // first memcpy that happens from argument passing, and pointer analysis
     // first memcpy that happens from argument passing, and pointer analysis
     // will not reveal that, especially if we've done a first SROA pass on V.
     // will not reveal that, especially if we've done a first SROA pass on V.
+    // No DomTree needed for that reason
     const bool bAllowReplace = false;
     const bool bAllowReplace = false;
-    SROA_Helper::LowerMemcpy(V, &annotation, dxilTypeSys, DL, bAllowReplace);
+    SROA_Helper::LowerMemcpy(V, &annotation, dxilTypeSys, DL, nullptr /*DT */, bAllowReplace);
 
 
     // Now is safe to create the IRBuilders.
     // Now is safe to create the IRBuilders.
     // If we create it before LowerMemcpy, the insertion pointer instruction may get deleted
     // If we create it before LowerMemcpy, the insertion pointer instruction may get deleted
@@ -5064,10 +5020,11 @@ void SROA_Parameter_HLSL::flattenArgument(
     Type *BrokenUpTy = nullptr;
     Type *BrokenUpTy = nullptr;
     uint64_t NumInstances = 1;
     uint64_t NumInstances = 1;
     if (inputQual != DxilParamInputQual::InPayload) {
     if (inputQual != DxilParamInputQual::InPayload) {
+      // DomTree isn't used by arguments
       SROAed = SROA_Helper::DoScalarReplacement(
       SROAed = SROA_Helper::DoScalarReplacement(
         V, Elts, BrokenUpTy, NumInstances, Builder, 
         V, Elts, BrokenUpTy, NumInstances, Builder, 
         /*bFlatVector*/ false, annotation.IsPrecise(),
         /*bFlatVector*/ false, annotation.IsPrecise(),
-        dxilTypeSys, DL, DeadInsts);
+        dxilTypeSys, DL, DeadInsts, /*DT*/ nullptr);
     }
     }
 
 
     if (SROAed) {
     if (SROAed) {

+ 43 - 0
tools/clang/test/HLSLFileCheck/passes/hl/sroa_hlsl/memcpy_domuser.hlsl

@@ -0,0 +1,43 @@
+// RUN: %dxc -T ps_6_0 %s | FileCheck %s
+
+// Test memcpy with src that dominates, but doesn't postdominate its uses
+
+// This produces suboptimal code when postdom check is used
+// When corrected, RAUW can be used
+
+// CHECK: @main
+// CHECK-NOT: memcpy
+// CHECK-NOT: = br
+// CHECK: icmp slt i32
+// CHECK: icmp sgt i32
+// CHECK: and i1
+// CHECK: br i1
+// CHECK: cbufferLoadLegacy
+// CHECK: fadd
+// CHECK: ret void
+struct OuterStruct
+{
+  float fval;
+  float fval2;
+};
+
+cbuffer cbuf : register(b1)
+{
+ OuterStruct g_oStruct[1];
+};
+
+float main(int doit : A, int dontit: B) : SV_Target
+{
+  OuterStruct oStruct;
+  float addon = 0.0;
+  // break up blocks so not in entry
+  // RAUW is able to combine these conditionals with an AND
+  // ld/str inserts loads between them at least temporarily
+  if (doit < 10) {
+    oStruct = g_oStruct[doit];
+    if (dontit > 0)
+      addon += oStruct.fval2 + oStruct.fval; // dominated, but not post dominated by memcpy. so it gets ld/str
+  }
+
+  return addon;
+}

+ 49 - 0
tools/clang/test/HLSLFileCheck/passes/hl/sroa_hlsl/memcpy_preuser.hlsl

@@ -0,0 +1,49 @@
+// RUN: %dxc -T ps_6_0 %s | FileCheck %s
+
+// A test that has the only user of a memcpy before it
+// Meant to test the differnce between dominator trees
+
+// this could theoretically cause a big problem by using postdom instead of dom
+// to determine if a memcpy src can replace all the dests.
+// The loop after if postdoms the if, so the src of that memcpy
+// replaces the ostruct.fval2 in the if
+// however, the way dxilgen handles cbuffers, it fixes it all up anyway.
+// Nevertheless, the weirdo backward dependency this creates results
+// in more complicated code when RAUW is incorrectly used rather than ld/str
+// Including placing a lot of the cbuffer loading in the inner if block
+
+// CHECK: @main
+// CHECK-NOT: memcpy
+
+// broken case will have no select and cbuffer load will precede fadd
+// CHECK: fadd
+// CHECK: select i1
+// CHECK: cbufferLoadLegacy
+// CHECK: ret void
+struct OuterStruct
+{
+  float fval;
+  float fval2;
+};
+
+cbuffer cbuf : register(b1)
+{
+ OuterStruct g_oStruct[1];
+};
+
+float main(int doit : A) : SV_Target
+{
+  float res = 0.0;
+  // Need a loop so the dest user can come before the memcpy
+  for (int i = 0; i < doit; i++) {
+    OuterStruct oStruct;
+    // This should be expressable as a select unless a bunch of mem stuff gets crammed in
+    if(i%2 == 0) {
+      res += oStruct.fval2;
+    }
+    // This block post dominates the if block
+    oStruct = g_oStruct[doit];
+  }
+
+  return res;
+}

+ 51 - 0
tools/clang/test/HLSLFileCheck/passes/hl/sroa_hlsl/memcpy_split_replace.hlsl

@@ -0,0 +1,51 @@
+// RUN: %dxc -T ps_6_0 %s | FileCheck %s
+
+// A regression test for sub-memcpys
+// These might get eliminated if their attempted replacement removes them.
+// The goal is to create a memcpy that fails to RAUW the first two levels
+// but then finds itself able to for the last memcpy
+// If handled improperly, the result returned will be undefined and fail validation
+
+// CHECK: @main
+// CHECK-NOT: memcpy
+// CHECK: fmul
+// CHECK-NOT: float undef
+// CHECK: phi float
+// CHECK: ret void
+struct OuterStruct
+{
+  float fval;
+  struct InnerStruct {
+    float val1;
+    float val2;
+  } Array[3];
+};
+
+cbuffer cbuf : register(b1)
+{
+ OuterStruct g_oStruct[1];
+};
+
+float main(int doit : A) : SV_Target
+{
+  OuterStruct oStruct;
+
+  // Need a conditional so the memcpy source won't dominate the output
+  for (; doit >= 0; --doit) {
+    uint multiplier = 4;
+    // Because the struct is copied within conditional block, the source will be
+    // a GEP in the if statement which won't dominate the usage, thwarting the first RAUW replacement
+    // Copying twice thwarts the first RAUW
+    oStruct = g_oStruct[doit];
+    oStruct = g_oStruct[doit];
+    // Must use the struct twice to thwart the second RAUW replacement
+    // At this stage, trivial reusage is enough.
+    multiplier = oStruct.Array[0].val2 + oStruct.Array[0].val1;
+
+    // If memcpy is wrong, undef floats will be part of this calculation
+    oStruct.fval *= multiplier;
+  }
+
+  // This makes use of the memory that should be copied as part of the sub-memcpy
+  return oStruct.fval;
+}

+ 63 - 0
tools/clang/test/HLSLFileCheck/passes/hl/sroa_hlsl/memcpy_split_replace2.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -T ps_6_0 %s | FileCheck %s
+
+// A test for different lowermemcpy approaches at different levels.
+// The original memcpy is to copy the entire struct.
+// If only used once, the src location can replace the dest location
+// But by being used twice, this isn't possible.
+// The next memcpy which is meant to copy the array is similarly thwarted by double use
+// The last level, which is a memcpy for the struct array elements has only one usage
+// so it seems able to use RAUW. However, this memcpy came from the splitting of the
+// previous memcopies. The src and dst are GEPs fashioned expressly to copy just this
+// portion of the aggragate to satisfy whatever requirements might come later.
+// So src and dst don't seem to have any other users than the memcpy they were created for.
+// When that memcpy is deleted, thinking it has been replaced, these GEPs are left without
+// users and promptly removed. In fact, they do have users, but it is not apparent becuase
+// at this stage, the user only references the original complete aggregate alloca.
+// when mem2reg tries to convert the regions of memory that these would have populated
+// it determines they are empty and makes them undefined (0),
+// which ultimately removes much of the code.
+
+// Very similar to memcpy_split_replace, but this one produces validation errors when unfixed
+
+
+// CHECK: @main
+// CHECK-NOT: memcpy
+// CHECK: fmul
+// CHECK-NOT: float undef
+// CHECK: ret void
+struct OuterStruct
+{
+  float fval;
+  struct InnerStruct {
+    float val1;
+    float val2;
+  } Array[3];
+};
+
+cbuffer cbuf : register(b1)
+{
+ OuterStruct g_oStruct[4];
+};
+
+float main(int doit : A) : SV_Target
+{
+  OuterStruct oStruct;
+
+  {
+    float multiplier = 4.0;
+    // At this stage, even trivial double usage will thwart RAUW at the top level
+    // Neither conditional nor +1 are needed, but make a more credible usage pattern
+    oStruct = g_oStruct[doit];
+    if (doit)
+      oStruct = g_oStruct[doit+1];
+    // Double usage of the array element thwarts it at the second level too.
+    // However, each scalar member is used just once, allowing for RAUW
+    multiplier = oStruct.Array[1].val1 +
+      oStruct.Array[1].val2;
+
+    // If memcpy is wrong, undef floats will be part of this calculation
+    oStruct.fval *= multiplier;
+  }
+
+  return oStruct.fval;
+}