Browse Source

Lower memcpy before SROA of a pointer. (#253)

1. Try to propagate the memcpy source if the pointer is written only once, by that memcpy.
2. Use the correct size for the memcpy created in AppendBuffer::Append.
3. Ignore unused Constant users of a matrix value.
4. When replacing a Constant with an instruction, also replace its ConstantExpr users with their instruction versions.
5. When matching the size of a memcpy, go deeper if the struct has only 1 element.
   Doing this leaves the memcpy to be lowered at the deeper level.
6. Do not replace pointers in SimplePtrCopy, because it cannot decide whether the replacement is safe.
Xiang Li 8 years ago
parent
commit
9f437adda5

+ 5 - 0
lib/HLSL/HLMatrixLowerPass.cpp

@@ -2059,6 +2059,9 @@ void HLMatrixLowerPass::DeleteDeadInsts() {
 static bool OnlyUsedByMatrixLdSt(Value *V) {
   bool onlyLdSt = true;
   for (User *user : V->users()) {
+    if (isa<Constant>(user) && user->use_empty())
+      continue;
+
     CallInst *CI = cast<CallInst>(user);
     if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
         HLOpcodeGroup::HLMatLoadStore)
@@ -2245,6 +2248,8 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
       vecGlobals[i] = EltGV;
     }
     for (User *user : GV->users()) {
+      if (isa<Constant>(user) && user->use_empty())
+        continue;
       CallInst *CI = cast<CallInst>(user);
       TranslateMatLoadStoreOnGlobal(GV, vecGlobals, CI);
       AddToDeadInsts(CI);

+ 365 - 76
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -82,7 +82,12 @@ public:
                                   IRBuilder<> &Builder, bool bFlatVector,
                                   bool hasPrecise, DxilTypeSystem &typeSys,
                                   SmallVector<Value *, 32> &DeadInsts);
-  static void MarkEmptyStructUsers(Value *V, SmallVector<Value *, 32> &DeadInsts);
+  // Lower memcpy related to V.
+  static bool LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
+                          DxilTypeSystem &typeSys, const DataLayout &DL,
+                          bool bAllowReplace);
+  static void MarkEmptyStructUsers(Value *V,
+                                   SmallVector<Value *, 32> &DeadInsts);
   static bool IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys);
 private:
   SROA_Helper(Value *V, ArrayRef<Value *> Elts,
@@ -254,9 +259,12 @@ public:
   MemcpySplitter(llvm::LLVMContext &context, DxilTypeSystem &typeSys)
       : m_context(context), m_typeSys(typeSys) {}
   void Split(llvm::Function &F);
+
+  static void PatchMemCpyWithZeroIdxGEP(Module &M);
+  static void PatchMemCpyWithZeroIdxGEP(MemCpyInst *MI, const DataLayout &DL);
   static void SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
                           DxilFieldAnnotation *fieldAnnotation,
-                          DxilTypeSystem &typeSys, bool bAllowReplace);
+                          DxilTypeSystem &typeSys);
 };
 
 }
@@ -1057,11 +1065,11 @@ bool SROA_HLSL::runOnFunction(Function &F) {
   HLModule &HLM = M->GetOrCreateHLModule();
   DxilTypeSystem &typeSys = HLM.GetTypeSystem();
 
-  // change memcpy into ld/st first
+  bool Changed = performScalarRepl(F, typeSys);
+  // change rest memcpy into ld/st.
   MemcpySplitter splitter(F.getContext(), typeSys);
   splitter.Split(F);
 
-  bool Changed = performScalarRepl(F, typeSys);
   Changed |= markPrecise(F);
   Changed |= performPromotion(F);
 
@@ -1548,6 +1556,12 @@ bool SROA_HLSL::performScalarRepl(Function &F, DxilTypeSystem &typeSys) {
         Changed = true;
         continue;
       }
+      const bool bAllowReplace = true;
+      if (SROA_Helper::LowerMemcpy(AI, /*annotation*/ nullptr, typeSys, DL,
+                                   bAllowReplace)) {
+        Changed = true;
+        continue;
+      }
 
       // If this alloca is impossible for us to promote, reject it early.
       if (AI->isArrayAllocation() || !AI->getAllocatedType()->isSized())
@@ -2121,23 +2135,14 @@ bool SROA_HLSL::isSafeAllocaToScalarRepl(AllocaInst *AI) {
 
 // Copy data from srcPtr to destPtr.
 static void SimplePtrCopy(Value *DestPtr, Value *SrcPtr,
-                       llvm::SmallVector<llvm::Value *, 16> &idxList,
-                       bool bAllowReplace,
-                       IRBuilder<> &Builder) {
-  // If src only has one use, just replace Dest with Src.
-  if (bAllowReplace && SrcPtr->hasOneUse() &&
-        // Only 2 uses for dest: 1 for memcpy, 1 for other use.
-      !DestPtr->hasNUsesOrMore(3) &&
-      !isa<CallInst>(SrcPtr) && !isa<CallInst>(DestPtr)) {
-    DestPtr->replaceAllUsesWith(SrcPtr);
-  } else {
-    if (idxList.size() > 1) {
-      DestPtr = Builder.CreateInBoundsGEP(DestPtr, idxList);
-      SrcPtr = Builder.CreateInBoundsGEP(SrcPtr, idxList);
-    }
-    llvm::LoadInst *ld = Builder.CreateLoad(SrcPtr);
-    Builder.CreateStore(ld, DestPtr);
-  }
+                          llvm::SmallVector<llvm::Value *, 16> &idxList,
+                          IRBuilder<> &Builder) {
+  if (idxList.size() > 1) {
+    DestPtr = Builder.CreateInBoundsGEP(DestPtr, idxList);
+    SrcPtr = Builder.CreateInBoundsGEP(SrcPtr, idxList);
+  }
+  llvm::LoadInst *ld = Builder.CreateLoad(SrcPtr);
+  Builder.CreateStore(ld, DestPtr);
 }
 
 // Copy srcVal to destPtr.
@@ -2160,36 +2165,39 @@ static void SimpleValCopy(Value *DestPtr, Value *SrcVal,
 
 static void SimpleCopy(Value *Dest, Value *Src,
                        llvm::SmallVector<llvm::Value *, 16> &idxList,
-                       bool bAllowReplace,
                        IRBuilder<> &Builder) {
   if (Src->getType()->isPointerTy())
-    SimplePtrCopy(Dest, Src, idxList, bAllowReplace, Builder);
+    SimplePtrCopy(Dest, Src, idxList, Builder);
   else
     SimpleValCopy(Dest, Src, idxList, Builder);
 }
 // Split copy into ld/st.
 static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
-                     SmallVector<Value *, 16> &idxList, bool bAllowReplace,
-                     IRBuilder<> &Builder, DxilTypeSystem &typeSys,
+                     SmallVector<Value *, 16> &idxList, IRBuilder<> &Builder,
+                     DxilTypeSystem &typeSys,
                      DxilFieldAnnotation *fieldAnnotation) {
   if (PointerType *PT = dyn_cast<PointerType>(Ty)) {
     Constant *idx = Constant::getIntegerValue(
         IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
     idxList.emplace_back(idx);
 
-    SplitCpy(PT->getElementType(), Dest, Src, idxList, bAllowReplace, Builder,
-             typeSys, fieldAnnotation);
+    SplitCpy(PT->getElementType(), Dest, Src, idxList, Builder, typeSys,
+             fieldAnnotation);
 
     idxList.pop_back();
   } else if (HLMatrixLower::IsMatrixType(Ty)) {
-    DXASSERT(fieldAnnotation, "require fieldAnnotation here");
-    DXASSERT(fieldAnnotation->HasMatrixAnnotation(),
-             "must has matrix annotation");
+    // If no fieldAnnotation, use row major as default.
+    // Only load then store immediately should be fine.
+    bool bRowMajor = true;
+    if (fieldAnnotation) {
+      DXASSERT(fieldAnnotation->HasMatrixAnnotation(),
+               "must has matrix annotation");
+      bRowMajor = fieldAnnotation->GetMatrixAnnotation().Orientation ==
+                  MatrixOrientation::RowMajor;
+    }
     Module *M = Builder.GetInsertPoint()->getModule();
     Value *DestGEP = Builder.CreateInBoundsGEP(Dest, idxList);
     Value *SrcGEP = Builder.CreateInBoundsGEP(Src, idxList);
-    bool bRowMajor = fieldAnnotation->GetMatrixAnnotation().Orientation ==
-                     MatrixOrientation::RowMajor;
     if (bRowMajor) {
       Value *Load = HLModule::EmitHLOperationCall(
           Builder, HLOpcodeGroup::HLMatLoadStore,
@@ -2216,7 +2224,7 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
   } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
     if (HLModule::IsHLSLObjectType(ST)) {
       // Avoid split HLSL object.
-      SimpleCopy(Dest, Src, idxList, bAllowReplace, Builder);
+      SimpleCopy(Dest, Src, idxList, Builder);
       return;
     }
     DxilStructAnnotation *STA = typeSys.GetStructAnnotation(ST);
@@ -2228,8 +2236,7 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
           IntegerType::get(Ty->getContext(), 32), APInt(32, i));
       idxList.emplace_back(idx);
       DxilFieldAnnotation &EltAnnotation = STA->GetFieldAnnotation(i);
-      SplitCpy(ET, Dest, Src, idxList, bAllowReplace, Builder, typeSys,
-               &EltAnnotation);
+      SplitCpy(ET, Dest, Src, idxList, Builder, typeSys, &EltAnnotation);
 
       idxList.pop_back();
     }
@@ -2241,13 +2248,12 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
       Constant *idx = Constant::getIntegerValue(
           IntegerType::get(Ty->getContext(), 32), APInt(32, i));
       idxList.emplace_back(idx);
-      SplitCpy(ET, Dest, Src, idxList, bAllowReplace, Builder, typeSys,
-               fieldAnnotation);
+      SplitCpy(ET, Dest, Src, idxList, Builder, typeSys, fieldAnnotation);
 
       idxList.pop_back();
     }
   } else {
-    SimpleCopy(Dest, Src, idxList, bAllowReplace, Builder);
+    SimpleCopy(Dest, Src, idxList, Builder);
   }
 }
 
@@ -2319,8 +2325,18 @@ static void SplitPtr(Type *Ty, Value *Ptr, SmallVector<Value *, 16> &idxList,
 static unsigned MatchSizeByCheckElementType(Type *Ty, const DataLayout &DL, unsigned size, unsigned level) {
   unsigned ptrSize = DL.getTypeAllocSize(Ty);
   // Size match, return current level.
-  if (ptrSize == size)
+  if (ptrSize == size) {
+    // For struct, go deeper if size not change.
+    // This will leave memcpy to deeper level when flatten.
+    if (StructType *ST = dyn_cast<StructType>(Ty)) {
+      if (ST->getNumElements() == 1) {
+        return MatchSizeByCheckElementType(ST->getElementType(0), DL, size, level+1);
+      }
+    }
+    // Don't do this for array.
+    // Array will be flattened as struct of array.
     return level;
+  }
   // Add ZeroIdx cannot make ptrSize bigger.
   if (ptrSize < size)
     return 0;
@@ -2335,12 +2351,26 @@ static unsigned MatchSizeByCheckElementType(Type *Ty, const DataLayout &DL, unsi
   }
 }
 
-void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
-                                 DxilFieldAnnotation *fieldAnnotation,
-                                 DxilTypeSystem &typeSys, bool bAllowReplace) {
-  Value *Op0 = MI->getOperand(0);
-  Value *Op1 = MI->getOperand(1);
+// Rewrite one pointer operand of MI so that it is a bitcast of an explicit
+// all-zero-index GEP on Ptr instead of RawPtr.
+//   Ptr    - pointer the GEP is built on.
+//   RawPtr - operand of MI that gets replaced; erased afterwards when it is an
+//            instruction that no longer has users.
+//   level  - the GEP receives level + 1 zero indices.
+static void PatchZeroIdxGEP(Value *Ptr, Value *RawPtr, MemCpyInst *MI,
+                            unsigned level, IRBuilder<> &Builder) {
+  Value *Zero = Builder.getInt32(0);
+  SmallVector<Value *, 2> ZeroIndices(level + 1, Zero);
+  Value *ElementPtr = Builder.CreateInBoundsGEP(Ptr, ZeroIndices);
+
+  // Create the cast through BitCastInst::Create and insert it by hand so the
+  // zero indices are not optimized away.
+  CastInst *PatchedPtr = BitCastInst::Create(Instruction::BitCast, ElementPtr,
+                                             RawPtr->getType());
+  Builder.Insert(PatchedPtr);
+  MI->replaceUsesOfWith(RawPtr, PatchedPtr);
+
+  // Drop the old operand when nothing uses it any more.
+  Instruction *OldInst = dyn_cast<Instruction>(RawPtr);
+  if (OldInst && RawPtr->user_empty())
+    OldInst->eraseFromParent();
+}
 
+void MemcpySplitter::PatchMemCpyWithZeroIdxGEP(MemCpyInst *MI,
+                                               const DataLayout &DL) {
   Value *Dest = MI->getRawDest();
   Value *Src = MI->getRawSource();
   // Only remove one level bitcast generated from inline.
@@ -2356,17 +2386,47 @@ void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
   // bitcast ptr.
   ConstantInt *Length = cast<ConstantInt>(MI->getLength());
   unsigned size = Length->getLimitedValue();
-  Value *zeroIdx = Builder.getInt32(0);
   if (unsigned level = MatchSizeByCheckElementType(DestTy, DL, size, 0)) {
-    SmallVector<Value *, 2> IdxList(level + 1, zeroIdx);
-    Dest = Builder.CreateInBoundsGEP(Dest, IdxList);
-    DestTy = Dest->getType()->getPointerElementType();
+    PatchZeroIdxGEP(Dest, MI->getRawDest(), MI, level, Builder);
   }
   if (unsigned level = MatchSizeByCheckElementType(SrcTy, DL, size, 0)) {
-    SmallVector<Value *, 2> IdxList(level + 1, zeroIdx);
-    Src = Builder.CreateInBoundsGEP(Src, IdxList);
-    SrcTy = Src->getType()->getPointerElementType();
+    PatchZeroIdxGEP(Src, MI->getRawSource(), MI, level, Builder);
+  }
+}
+
+void MemcpySplitter::PatchMemCpyWithZeroIdxGEP(Module &M) { // Applies the per-memcpy patch to every memcpy instruction in every function of M.
+  const DataLayout &DL = M.getDataLayout();
+  for (Function &F : M.functions()) {
+    for (Function::iterator BB = F.begin(), BBE = F.end(); BB != BBE; ++BB) {
+      for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) {
+        // Step the iterator before patching so it stays valid while the patch edits instructions around I.
+        Instruction *I = BI++;
+
+        if (MemCpyInst *MI = dyn_cast<MemCpyInst>(I)) {
+          PatchMemCpyWithZeroIdxGEP(MI, DL);
+        }
+      }
+    }
   }
+}
+
+void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
+                                 DxilFieldAnnotation *fieldAnnotation,
+                                 DxilTypeSystem &typeSys) {
+  Value *Op0 = MI->getOperand(0);
+  Value *Op1 = MI->getOperand(1);
+
+  Value *Dest = MI->getRawDest();
+  Value *Src = MI->getRawSource();
+  // Only remove one level bitcast generated from inline.
+  if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Dest))
+    Dest = BC->getOperand(0);
+  if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
+    Src = BC->getOperand(0);
+
+  IRBuilder<> Builder(MI);
+  Type *DestTy = Dest->getType()->getPointerElementType();
+  Type *SrcTy = Src->getType()->getPointerElementType();
 
   // Allow copy between different address space.
   if (DestTy != SrcTy) {
@@ -2377,7 +2437,7 @@ void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
   // split
   // Matrix is treated as scalar type, will not use memcpy.
   // So use nullptr for fieldAnnotation should be safe here.
-  SplitCpy(Dest->getType(), Dest, Src, idxList, bAllowReplace, Builder, typeSys,
+  SplitCpy(Dest->getType(), Dest, Src, idxList, Builder, typeSys,
            fieldAnnotation);
   // delete memcpy
   MI->eraseFromParent();
@@ -2402,7 +2462,7 @@ void MemcpySplitter::Split(llvm::Function &F) {
       if (MemCpyInst *MI = dyn_cast<MemCpyInst>(I)) {
         // Matrix is treated as scalar type, will not use memcpy.
         // So use nullptr for fieldAnnotation should be safe here.
-        SplitMemCpy(MI, DL, /*fieldAnnotation*/ nullptr, m_typeSys, /*bAllowReplace*/ false);
+        SplitMemCpy(MI, DL, /*fieldAnnotation*/ nullptr, m_typeSys);
       }
     }
   }
@@ -3376,6 +3436,231 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &
   return true;
 }
 
+struct PointerStatus {
+  /// Keep track of what writes to the pointer look like.
+  enum StoredType {
+    /// No write to this ptr has been observed.
+    NotStored,
+
+    /// This ptr is a global variable whose initializer is present and is not
+    /// undef;
+    /// the initializer is treated as the value stored so far.
+    InitializerStored,
+
+    /// A single store instruction user was seen while nothing else had been
+    /// recorded as writing this ptr.  The value operand of that store is kept
+    /// in StoredOnceValue below.  A store seen in any other state moves the
+    /// status to Stored.
+    StoredOnce,
+
+    /// The only write seen is one full-size memcpy whose raw dest is this ptr.
+    MemcopyDestOnce,
+
+    /// This ptr is stored to by multiple values or something else that we
+    /// cannot track.
+    Stored
+  } StoredType;
+
+  /// Value operand of the single store while StoredType is StoredOnce;
+  /// nullptr initially and again after MarkAsStored().
+  Value *StoredOnceValue;
+  /// Memcpys that use this ptr, as collected by analyzePointer.
+  std::vector<MemCpyInst *> memcpyList;
+  /// Memcpy recorded as the single writer of this ptr (see MemcopyDestOnce).
+  MemCpyInst *StoringMemcpy;
+  /// These start out null/false.  They are intended to track the function(s)
+  /// whose instructions use this ptr; HasMultipleAccessingFunctions is meant
+  /// to become true once a second, different function is noticed.
+  const Function *AccessingFunction;
+  bool HasMultipleAccessingFunctions;
+  /// Size the ptr covers; a constant-length memcpy of this size is a full copy.
+  unsigned Size;
+
+  /// Look at all uses of V, following bitcasts and GEPs, and fill in PS.
+  /// Nothing is returned; a write that cannot be tracked is reported by
+  /// setting PS.StoredType to Stored.
+  static void analyzePointer(const Value *V, PointerStatus &PS,
+                             DxilTypeSystem &typeSys, bool bStructElt);
+
+  PointerStatus(unsigned size)
+      : StoredType(NotStored), StoredOnceValue(nullptr), StoringMemcpy(nullptr),
+        AccessingFunction(nullptr), HasMultipleAccessingFunctions(false),
+        Size(size) {}
+  void MarkAsStored() {
+    StoredType = PointerStatus::StoredType::Stored;
+    StoredOnceValue = nullptr;
+  }
+};
+
+/// Walk the users of V and record in PS how the pointer is written and which
+/// memcpys use it.
+///
+/// Bitcast and GEP users are followed recursively and accumulate into the same
+/// PS.  A write that cannot be classified moves PS.StoredType to Stored.
+///
+/// \param V          Pointer (or a bitcast/GEP derived from it) to analyze.
+/// \param PS         Status accumulated over V and everything derived from it.
+/// \param typeSys    Supplies the function annotations used to classify call
+///                   arguments.
+/// \param bStructElt True when V was reached through a GEP that indexes into
+///                   a struct; memcpys on such pointers are not collected.
+void PointerStatus::analyzePointer(const Value *V, PointerStatus &PS,
+                                   DxilTypeSystem &typeSys, bool bStructElt) {
+  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
+    // A present, non-undef initializer counts as a value already stored.
+    if (GV->hasInitializer() && !isa<UndefValue>(GV->getInitializer())) {
+      PS.StoredType = PointerStatus::StoredType::InitializerStored;
+    }
+  }
+
+  for (const User *U : V->users()) {
+    if (const Instruction *I = dyn_cast<Instruction>(U)) {
+      const Function *F = I->getParent()->getParent();
+      if (!PS.AccessingFunction) {
+        // Record the first function seen so later users can be compared
+        // against it.
+        PS.AccessingFunction = F;
+      } else {
+        if (F != PS.AccessingFunction)
+          PS.HasMultipleAccessingFunctions = true;
+      }
+    }
+
+    if (const BitCastOperator *BC = dyn_cast<BitCastOperator>(U)) {
+      // Look through the cast; its users are uses of the same pointer.
+      analyzePointer(BC, PS, typeSys, bStructElt);
+    } else if (const MemCpyInst *MC = dyn_cast<MemCpyInst>(U)) {
+      // Do not collect memcpy on struct GEP use.
+      // These memcpy will be flattened in next level.
+      if (!bStructElt) {
+        PS.memcpyList.emplace_back(const_cast<MemCpyInst *>(MC));
+        // A full copy has a constant length equal to the tracked size.
+        bool bFullCopy = false;
+        if (ConstantInt *Length = dyn_cast<ConstantInt>(MC->getLength())) {
+          bFullCopy = PS.Size == Length->getLimitedValue();
+        }
+        if (MC->getRawDest() == V) {
+          if (bFullCopy &&
+              PS.StoredType == PointerStatus::StoredType::NotStored) {
+            // First write seen, and it overwrites the whole pointee.
+            PS.StoredType = PointerStatus::StoredType::MemcopyDestOnce;
+            PS.StoringMemcpy = PS.memcpyList.back();
+          } else {
+            PS.MarkAsStored();
+            PS.StoringMemcpy = nullptr;
+          }
+        }
+      } else {
+        PS.MarkAsStored();
+      }
+    } else if (const GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
+      gep_type_iterator GEPIt = gep_type_begin(GEP);
+      gep_type_iterator GEPEnd = gep_type_end(GEP);
+      // Skip pointer idx.
+      GEPIt++;
+      // Struct elt will be flattened in next level.
+      bool bGEPStructElt = (GEPIt != GEPEnd) && GEPIt->isStructTy();
+      analyzePointer(GEP, PS, typeSys, bGEPStructElt);
+    } else if (const StoreInst *SI = dyn_cast<StoreInst>(U)) {
+      Value *StoredVal = SI->getOperand(0);
+
+      if (PS.StoredType == PointerStatus::StoredType::NotStored) {
+        PS.StoredType = PointerStatus::StoredType::StoredOnce;
+        PS.StoredOnceValue = StoredVal;
+      } else {
+        PS.MarkAsStored();
+      }
+    } else if (const CallInst *CI = dyn_cast<CallInst>(U)) {
+      Function *F = CI->getCalledFunction();
+      DxilFunctionAnnotation *annotation = typeSys.GetFunctionAnnotation(F);
+      if (!annotation) {
+        // If not sure its out param or not. Take as out param.
+        PS.MarkAsStored();
+        continue;
+      }
+
+      // Passing V as anything other than an 'in' parameter may write it.
+      unsigned argSize = F->arg_size();
+      for (unsigned i = 0; i < argSize; i++) {
+        Value *arg = CI->getArgOperand(i);
+        if (V == arg) {
+          DxilParamInputQual inputQual =
+              annotation->GetParameterAnnotation(i).GetParamInputQual();
+          if (inputQual != DxilParamInputQual::In) {
+            PS.MarkAsStored();
+            break;
+          }
+        }
+      }
+    }
+  }
+}
+
+/// Replace every use of constant C with the non-constant value V.
+///
+/// Instruction users are rewritten in place.  Each ConstantExpr user is
+/// materialized as an instruction inserted through Builder, that instruction
+/// is pointed at V, and the ConstantExpr's own users are then rewritten
+/// recursively to use the new instruction.
+static void ReplaceConstantWithInst(Constant *C, Value *V, IRBuilder<> &Builder) {
+  // Snapshot the distinct users before rewriting anything.
+  // replaceUsesOfWith removes every use of C held by that user from C's use
+  // list, so walking the live list could hand back an iterator that has moved
+  // onto V's use list and skip the remaining users of C.
+  SmallVector<User *, 8> Users;
+  for (User *U : C->users()) {
+    bool bSeen = false;
+    for (User *Seen : Users) {
+      if (Seen == U) {
+        bSeen = true;
+        break;
+      }
+    }
+    if (!bSeen)
+      Users.emplace_back(U);
+  }
+
+  for (User *U : Users) {
+    if (Instruction *I = dyn_cast<Instruction>(U)) {
+      I->replaceUsesOfWith(C, V);
+    } else {
+      // A constant expression cannot take a non-constant operand; turn it
+      // into an instruction and redirect its users to that instruction.
+      ConstantExpr *CE = cast<ConstantExpr>(U);
+      Instruction *Inst = CE->getAsInstruction();
+      Builder.Insert(Inst);
+      Inst->replaceUsesOfWith(C, V);
+      ReplaceConstantWithInst(CE, Inst, Builder);
+    }
+  }
+}
+
+/// Redirect all uses of V to Src, then erase memcpy MC together with any of
+/// its two pointer operands that are instructions left without users.
+static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC) {
+  Constant *ConstV = dyn_cast<Constant>(V);
+  if (ConstV && !isa<Constant>(Src)) {
+    // Replace Constant with a non-Constant.
+    IRBuilder<> Builder(MC);
+    ReplaceConstantWithInst(ConstV, Src, Builder);
+  } else {
+    // Either V is not a constant or both sides are constants.
+    V->replaceAllUsesWith(Src);
+  }
+
+  // Read the operands after the replacement above, then drop the memcpy.
+  Value *RawDest = MC->getOperand(0);
+  Value *RawSrc = MC->getOperand(1);
+  MC->eraseFromParent();
+
+  auto EraseIfUnused = [](Value *Operand) {
+    Instruction *I = dyn_cast<Instruction>(Operand);
+    if (I && I->user_empty())
+      I->eraseFromParent();
+  };
+  EraseIfUnused(RawDest);
+  EraseIfUnused(RawSrc);
+}
+
+/// Lower the memcpys that use pointer V.
+///
+/// When bAllowReplace is set, V is written only by one full-size memcpy and V
+/// is not flagged as accessed from multiple functions, V is replaced by the
+/// memcpy source, provided both sides share an address space and the source
+/// itself is not in the Stored state.  Otherwise every collected memcpy is
+/// handed to MemcpySplitter::SplitMemCpy.
+///
+/// \returns true only when V was replaced by the memcpy source; false when V
+///          is not a pointer or the memcpys were split instead.
+bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
+                              DxilTypeSystem &typeSys, const DataLayout &DL,
+                              bool bAllowReplace) {
+  Type *PtrTy = V->getType();
+  if (!PtrTy->isPointerTy())
+    return false;
+
+  // Classify the writes to V and gather the memcpys that use it.
+  unsigned PointeeSize = DL.getTypeAllocSize(PtrTy->getPointerElementType());
+  PointerStatus Status(PointeeSize);
+  const bool bStructElt = false;
+  PointerStatus::analyzePointer(V, Status, typeSys, bStructElt);
+
+  const bool bOnlyMemcpyWrite =
+      Status.StoredType == PointerStatus::StoredType::MemcopyDestOnce;
+  if (bAllowReplace && bOnlyMemcpyWrite &&
+      !Status.HasMultipleAccessingFunctions) {
+    // How to make sure Src is not updated after Memcopy?
+
+    // Replace with src of memcpy.
+    MemCpyInst *Copy = Status.StoringMemcpy;
+    if (Copy->getSourceAddressSpace() == Copy->getDestAddressSpace()) {
+      // Only remove one level bitcast generated from inline.
+      Value *Src = Copy->getOperand(1);
+      if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
+        Src = BC->getOperand(0);
+
+      // Need to make sure src not updated after current memcpy.
+      // Check Src only have 1 store now.
+      PointerStatus SrcStatus(PointeeSize);
+      PointerStatus::analyzePointer(Src, SrcStatus, typeSys, bStructElt);
+      if (SrcStatus.StoredType != PointerStatus::StoredType::Stored) {
+        ReplaceMemcpy(V, Src, Copy);
+        return true;
+      }
+    }
+  }
+
+  // No replacement happened: split every memcpy that uses V.
+  for (MemCpyInst *Copy : Status.memcpyList)
+    MemcpySplitter::SplitMemCpy(Copy, DL, annotation, typeSys);
+  return false;
+}
+
 /// MarkEmptyStructUsers - Add instruction related to Empty struct to DeadInsts.
 void SROA_Helper::MarkEmptyStructUsers(Value *V, SmallVector<Value *, 32> &DeadInsts) {
   for (User *U : V->users()) {
@@ -3416,7 +3701,11 @@ public:
   explicit SROA_Parameter_HLSL() : ModulePass(ID) {}
   const char *getPassName() const override { return "SROA Parameter HLSL"; }
 
-  bool runOnModule(Module &M) override { 
+  bool runOnModule(Module &M) override {
+    // Patch memcpy to cover case bitcast (gep ptr, 0,0) is transformed into
+    // bitcast ptr.
+    MemcpySplitter::PatchMemCpyWithZeroIdxGEP(M);
+
     m_pHLModule = &M.GetOrCreateHLModule();
 
     // Load up debug information, to cross-reference values and the instructions
@@ -3642,7 +3931,6 @@ void SROA_Parameter_HLSL::flattenGlobal(GlobalVariable *GV) {
   const DataLayout &DL = GV->getParent()->getDataLayout();
   unsigned debugOffset = 0;
   std::unordered_map<Value*, StringRef> EltNameMap;
-  bool isFlattened = false;
   // Process the worklist
   while (!WorkList.empty()) {
     GlobalVariable *EltGV = cast<GlobalVariable>(WorkList.front());
@@ -3650,6 +3938,12 @@ void SROA_Parameter_HLSL::flattenGlobal(GlobalVariable *GV) {
     // Flat Global vector if no dynamic vector indexing.
     bool bFlatVector = !hasDynamicVectorIndexing(EltGV);
 
+    const bool bAllowReplace = true;
+    if (SROA_Helper::LowerMemcpy(EltGV, /*annoation*/ nullptr, dxilTypeSys, DL,
+                                 bAllowReplace)) {
+      continue;
+    }
+
     std::vector<Value *> Elts;
     bool SROAed = SROA_Helper::DoScalarReplacement(
         EltGV, Elts, Builder, bFlatVector,
@@ -3684,14 +3978,12 @@ void SROA_Parameter_HLSL::flattenGlobal(GlobalVariable *GV) {
             EltNameMap[EltGV]);
         debugOffset += size;
       }
-      if (GV != EltGV)
-        isFlattened = true;
     }
   }
 
   DeleteDeadInstructions();
 
-  if (isFlattened) {
+  if (GV->user_empty()) {
     GV->removeDeadConstantUsers();
     GV->eraseFromParent();
   }
@@ -4592,6 +4884,8 @@ void SROA_Parameter_HLSL::flattenArgument(
     // Do not skip unused parameter.
 
     DxilFieldAnnotation &annotation = annotationMap[V];
+    const bool bAllowReplace = !bOut;
+    SROA_Helper::LowerMemcpy(V, &annotation, dxilTypeSys, DL, bAllowReplace);
 
     std::vector<Value *> Elts;
     // Not flat vector for entry function currently.
@@ -4798,9 +5092,8 @@ void SROA_Parameter_HLSL::flattenArgument(
                 IRBuilder<> Builder(CI);
 
                 llvm::SmallVector<llvm::Value *, 16> idxList;
-                SplitCpy(data->getType(), outputVal, data, idxList,
-                         /*bAllowReplace*/ false, Builder, dxilTypeSys,
-                         &flatParamAnnotation);
+                SplitCpy(data->getType(), outputVal, data, idxList, Builder,
+                         dxilTypeSys, &flatParamAnnotation);
 
                 CI->setArgOperand(HLOperandIndex::kStreamAppendDataOpIndex, outputVal);
               }
@@ -4826,8 +5119,7 @@ void SROA_Parameter_HLSL::flattenArgument(
 
                   llvm::SmallVector<llvm::Value *, 16> idxList;
                   SplitCpy(DataPtr->getType(), EltPtr, DataPtr, idxList,
-                           /*bAllowReplace*/ false, Builder, dxilTypeSys,
-                           &flatParamAnnotation);
+                           Builder, dxilTypeSys, &flatParamAnnotation);
                   CI->setArgOperand(i, EltPtr);
                 }
               }
@@ -4899,8 +5191,8 @@ static void SplitArrayCopy(Value *V, DxilTypeSystem &typeSys,
       Value *val = ST->getValueOperand();
       IRBuilder<> Builder(ST);
       SmallVector<Value *, 16> idxList;
-      SplitCpy(ptr->getType(), ptr, val, idxList, /*bAllowReplace*/ true,
-               Builder, typeSys, fieldAnnotation);
+      SplitCpy(ptr->getType(), ptr, val, idxList, Builder, typeSys,
+               fieldAnnotation);
       ST->eraseFromParent();
     }
   }
@@ -5038,8 +5330,8 @@ static void LegalizeDxilInputOutputs(Function *F,
       if (bStoreInputToTemp) {
         llvm::SmallVector<llvm::Value *, 16> idxList;
         // split copy.
-        SplitCpy(temp->getType(), temp, &arg, idxList, /*bAllowReplace*/ false,
-                 Builder, typeSys, &paramAnnotation);
+        SplitCpy(temp->getType(), temp, &arg, idxList, Builder, typeSys,
+                 &paramAnnotation);
       }
 
       // Generate store output, temp later.
@@ -5068,8 +5360,8 @@ static void LegalizeDxilInputOutputs(Function *F,
         else
           onlyRetBlk = true;
         // split copy.
-        SplitCpy(output->getType(), output, temp, idxList,
-                 /*bAllowReplace*/ false, Builder, typeSys, &paramAnnotation);
+        SplitCpy(output->getType(), output, temp, idxList, Builder, typeSys,
+                 &paramAnnotation);
       }
       // Clone the return.
       Builder.CreateRet(RI->getReturnValue());
@@ -5080,9 +5372,6 @@ static void LegalizeDxilInputOutputs(Function *F,
 
 void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
   DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
-  // Change memcpy into ld/st first
-  MemcpySplitter splitter(F->getContext(), typeSys);
-  splitter.Split(*F);
 
   // Skip void (void) function.
   if (F->getReturnType()->isVoidTy() && F->getArgumentList().empty()) {
@@ -5419,8 +5708,8 @@ void SROA_Parameter_HLSL::createFlattenedFunctionCall(Function *F, Function *fla
         // Copy in param.
         llvm::SmallVector<llvm::Value *, 16> idxList;
         // split copy to avoid load of struct.
-        SplitCpy(Ty, tempArg, arg, idxList, /*bAllowReplace*/ false, CallBuilder,
-                 typeSys, &paramAnnotation);
+        SplitCpy(Ty, tempArg, arg, idxList, CallBuilder, typeSys,
+                 &paramAnnotation);
       }
 
       if (inputQual == DxilParamInputQual::Out ||
@@ -5428,8 +5717,8 @@ void SROA_Parameter_HLSL::createFlattenedFunctionCall(Function *F, Function *fla
         // Copy out param.
         llvm::SmallVector<llvm::Value *, 16> idxList;
         // split copy to avoid load of struct.
-        SplitCpy(Ty, arg, tempArg, idxList, /*bAllowReplace*/ false, RetBuilder,
-                 typeSys, &paramAnnotation);
+        SplitCpy(Ty, arg, tempArg, idxList, RetBuilder, typeSys,
+                 &paramAnnotation);
       }
       arg = tempArg;
       flattenArgument(flatF, arg, paramAnnotation, FlatParamList,

+ 2 - 6
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -2973,12 +2973,8 @@ static Function *CreateOpFunction(llvm::Module &M, Function *F,
         Argument *valArg = argIter;
         // Buf[counter] = val;
         if (valTy->isPointerTy()) {
-          Value *valArgCast = Builder.CreateBitCast(valArg, llvm::Type::getInt8PtrTy(F->getContext()));
-          Value *subscriptCast = Builder.CreateBitCast(subscript, llvm::Type::getInt8PtrTy(F->getContext()));
-          // TODO: use real type size and alignment.
-          Value *tySize = ConstantInt::get(idxTy, 8);
-          unsigned Align = 8;
-          Builder.CreateMemCpy(subscriptCast, valArgCast, tySize, Align);
+          unsigned size = M.getDataLayout().getTypeAllocSize(subscript->getType()->getPointerElementType());
+          Builder.CreateMemCpy(subscript, valArg, size, 1);
         } else
           Builder.CreateStore(valArg, subscript);
         Builder.CreateRetVoid();

+ 22 - 0
tools/clang/test/CodeGenHLSL/cbuffer_copy.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Make sure copying the cbuffer arrays a and b into static const ST needs no alloca.
+// CHECK-NOT: alloca
+
+cbuffer T
+{
+	float4 a[2];
+	float4 b[2];
+}
+static const struct
+{
+	float4 a[2];
+	float4 b[2];
+} ST = { a, b};
+
+uint i;
+
+float4 main() : SV_Target
+{
+  return ST.a[i] + ST.b[i];
+}

+ 22 - 0
tools/clang/test/CodeGenHLSL/cbuffer_copy2.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Make sure copying the cbuffer arrays into static const ST needs no alloca when a has a single element.
+// CHECK-NOT: alloca
+
+cbuffer T
+{
+	float4 a[1];
+	float4 b[2];
+}
+static const struct
+{
+	float4 a[1];
+	float4 b[2];
+} ST = { a, b};
+
+uint i;
+
+float4 main() : SV_Target
+{
+  return ST.a[i] + ST.b[i];
+}

+ 28 - 0
tools/clang/test/CodeGenHLSL/cbuffer_copy3.hlsl

@@ -0,0 +1,28 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Make sure copying cbuffer data into static const ST needs no alloca when ST also holds a one-field struct S.
+// CHECK-NOT: alloca
+
+struct S {
+  float s;
+};
+
+cbuffer T
+{
+        S  s;
+	float4 a[1];
+	float4 b[2];
+}
+static const struct
+{
+        S  s;
+	float4 a[1];
+	float4 b[2];
+} ST = { s, a, b};
+
+uint i;
+
+float4 main() : SV_Target
+{
+  return ST.a[i] + ST.b[i];
+}

+ 15 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -330,6 +330,9 @@ public:
   TEST_METHOD(CodeGenCast5)
   TEST_METHOD(CodeGenCast6)
   TEST_METHOD(CodeGenCbuf_init_static)
+  TEST_METHOD(CodeGenCbufferCopy)
+  TEST_METHOD(CodeGenCbufferCopy2)
+  TEST_METHOD(CodeGenCbufferCopy3)
   TEST_METHOD(CodeGenCbuffer_unused)
   TEST_METHOD(CodeGenCbuffer1_50)
   TEST_METHOD(CodeGenCbuffer1_51)
@@ -2232,6 +2235,18 @@ TEST_F(CompilerTest, CodeGenCbuf_init_static) {
   CodeGenTest(L"..\\CodeGenHLSL\\cbuf_init_static.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenCbufferCopy) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\cbuffer_copy.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenCbufferCopy2) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\cbuffer_copy2.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenCbufferCopy3) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\cbuffer_copy3.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenCbuffer_unused) {
   CodeGenTest(L"..\\CodeGenHLSL\\cbuffer_unused.hlsl");
 }