
Better memcpy propagation. (#1233)

1. Run MergeGepUse on cbuffer subscripts and on the temporary GEPs created during flattening.
2. When flattening allocas, always process them in order of type size, so the biggest alloca is
   flattened first and memcpys into the smaller allocas become easier to remove (see the sketch below).
3. For global variables with a zero initializer, replace uses with the zero value where possible.
4. In ResourceToHandle::ReplaceResourceWithHandle, only assert when the value still has users.
5. Run MergeGepUse on input/output values in collectInputOutputAccessInfo.
Author: Xiang Li (7 years ago)
Commit: e1fd0fc9de
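
Item 2 above is implemented in performScalarRepl (diffed below) by replacing the pre-sorted alloca list with a size-ordered priority queue, so the element allocas produced by splitting a big aggregate can be pushed back into the same work list and still be processed largest-first. The following is a minimal standalone sketch of that ordering only, assuming a hypothetical FakeAlloca struct in place of llvm::AllocaInst and llvm::DataLayout:

```cpp
// Standalone sketch of the largest-first work-list ordering used in
// performScalarRepl. FakeAlloca stands in for llvm::AllocaInst; allocSize
// stands in for DL.getTypeAllocSize(AI->getAllocatedType()).
#include <cstdint>
#include <functional>
#include <iostream>
#include <queue>
#include <vector>

struct FakeAlloca {
  const char *name;
  uint64_t allocSize; // bytes the allocation occupies
};

int main() {
  // The comparator returns true when a0 is smaller, so std::priority_queue
  // (a max-heap) pops the largest allocation first.
  auto size_cmp = [](const FakeAlloca *a0, const FakeAlloca *a1) {
    return a0->allocSize < a1->allocSize;
  };
  std::priority_queue<FakeAlloca *, std::vector<FakeAlloca *>,
                      std::function<bool(FakeAlloca *, FakeAlloca *)>>
      workList(size_cmp);

  FakeAlloca a{"small", 16}, b{"big", 400}, c{"medium", 64};
  for (FakeAlloca *ai : {&a, &b, &c})
    workList.push(ai);

  // "big" is split first; in the real pass the element allocas it produces
  // are pushed back into this queue, so a later memcpy from a small alloca
  // already sees flattened destinations.
  while (!workList.empty()) {
    FakeAlloca *ai = workList.top();
    workList.pop();
    std::cout << "process " << ai->name << " (" << ai->allocSize << " bytes)\n";
  }
}
```

Pushing split results back into the same queue is what the earlier std::sort-based version could not do; it is also why size_cmp flips from > to < in the diff, since std::priority_queue is a max-heap and expects a less-than comparator in order to pop the largest element first.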

+ 1 - 1
lib/HLSL/HLOperationLower.cpp

@@ -1264,7 +1264,7 @@ Value *TranslateAtan2(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
   Constant *halfPi = ConstantFP::get(Ty->getScalarType(), M_PI / 2);
   Constant *negHalfPi = ConstantFP::get(Ty->getScalarType(), -M_PI / 2);
   Constant *zero = ConstantFP::get(Ty->getScalarType(), 0);
-  if (Ty != Ty->getScalarType()) {
+  if (Ty->isVectorTy()) {
     unsigned vecSize = Ty->getVectorNumElements();
     pi = ConstantVector::getSplat(vecSize, pi);
     halfPi = ConstantVector::getSplat(vecSize, halfPi);

+ 3 - 3
lib/HLSL/HLSignatureLower.cpp

@@ -701,9 +701,9 @@ void collectInputOutputAccessInfo(
     Value *GV, Constant *constZero,
     std::vector<InputOutputAccessInfo> &accessInfoList, bool hasVertexID,
     bool bInput, bool bRowMajor) {
-  auto User = GV->user_begin();
-  auto UserE = GV->user_end();
-  for (; User != UserE;) {
+  // merge GEP use for input output.
+  HLModule::MergeGepUse(GV);
+  for (auto User = GV->user_begin(); User != GV->user_end();) {
     Value *I = *(User++);
     if (LoadInst *ldInst = dyn_cast<LoadInst>(I)) {
       if (bInput) {

+ 353 - 158
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -20,6 +20,7 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/PostDominators.h"
 #include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DIBuilder.h"
@@ -58,6 +59,7 @@
 #include <deque>
 #include <unordered_map>
 #include <unordered_set>
+#include <queue>
 
 using namespace llvm;
 using namespace hlsl;
@@ -77,11 +79,13 @@ public:
   static bool DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
                                   IRBuilder<> &Builder, bool bFlatVector,
                                   bool hasPrecise, DxilTypeSystem &typeSys,
+                                  const DataLayout &DL,
                                   SmallVector<Value *, 32> &DeadInsts);
 
   static bool DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &Elts,
                                   IRBuilder<> &Builder, bool bFlatVector,
                                   bool hasPrecise, DxilTypeSystem &typeSys,
+                                  const DataLayout &DL,
                                   SmallVector<Value *, 32> &DeadInsts);
   // Lower memcpy related to V.
   static bool LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
@@ -92,8 +96,9 @@ public:
   static bool IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys);
 private:
   SROA_Helper(Value *V, ArrayRef<Value *> Elts,
-              SmallVector<Value *, 32> &DeadInsts)
-      : OldVal(V), NewElts(Elts), DeadInsts(DeadInsts) {}
+              SmallVector<Value *, 32> &DeadInsts, DxilTypeSystem &ts,
+              const DataLayout &dl)
+      : OldVal(V), NewElts(Elts), DeadInsts(DeadInsts), typeSys(ts), DL(dl) {}
   void RewriteForScalarRepl(Value *V, IRBuilder<> &Builder);
 
 private:
@@ -102,6 +107,8 @@ private:
   // Flattened elements for OldVal.
   ArrayRef<Value*> NewElts;
   SmallVector<Value *, 32> &DeadInsts;
+  DxilTypeSystem  &typeSys;
+  const DataLayout &DL;
 
   void RewriteForConstExpr(ConstantExpr *user, IRBuilder<> &Builder);
   void RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder);
@@ -266,7 +273,8 @@ public:
   static void PatchMemCpyWithZeroIdxGEP(MemCpyInst *MI, const DataLayout &DL);
   static void SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
                           DxilFieldAnnotation *fieldAnnotation,
-                          DxilTypeSystem &typeSys);
+                          DxilTypeSystem &typeSys,
+                          const bool bEltMemCpy = true);
 };
 
 }
@@ -1524,134 +1532,138 @@ bool SROA_HLSL::ShouldAttemptScalarRepl(AllocaInst *AI) {
 bool SROA_HLSL::performScalarRepl(Function &F, DxilTypeSystem &typeSys) {
   std::vector<AllocaInst *> AllocaList;
   const DataLayout &DL = F.getParent()->getDataLayout();
-
-  // Scan the entry basic block, adding allocas to the worklist.
-  BasicBlock &BB = F.getEntryBlock();
-  for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I)
-    if (AllocaInst *A = dyn_cast<AllocaInst>(I)) {
-      if (A->hasNUsesOrMore(1))
-        AllocaList.emplace_back(A);
-    }
-
-  // merge GEP use for the allocs
-  for (auto A : AllocaList)
-    HLModule::MergeGepUse(A);
-
   // Make sure big alloca split first.
   // This will simplify memcpy check between part of big alloca and small
   // alloca. Big alloca will be split to smaller piece first, when process the
   // alloca, it will be alloca flattened from big alloca instead of a GEP of big
   // alloca.
   auto size_cmp = [&DL](const AllocaInst *a0, const AllocaInst *a1) -> bool {
-    return DL.getTypeAllocSize(a0->getAllocatedType()) >
+    return DL.getTypeAllocSize(a0->getAllocatedType()) <
            DL.getTypeAllocSize(a1->getAllocatedType());
   };
-
-  std::sort(AllocaList.begin(), AllocaList.end(), size_cmp);
+  std::priority_queue<AllocaInst *, std::vector<AllocaInst *>,
+                      std::function<bool(AllocaInst *, AllocaInst *)>>
+      WorkList(size_cmp);
+  std::unordered_map<AllocaInst*, DbgDeclareInst*> DDIMap;
+  // Scan the entry basic block, adding allocas to the worklist.
+  BasicBlock &BB = F.getEntryBlock();
+  for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I)
+    if (AllocaInst *A = dyn_cast<AllocaInst>(I)) {
+      if (!A->user_empty()) {
+        WorkList.push(A);
+        // merge GEP use for the allocs
+        HLModule::MergeGepUse(A);
+        if (DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(A)) {
+          DDIMap[A] = DDI;
+        }
+      }
+    }
 
   DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false);
 
   // Process the worklist
   bool Changed = false;
-  for (AllocaInst *Alloc : AllocaList) {
-    DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(Alloc);
-    unsigned debugOffset = 0;
-    std::deque<AllocaInst *> WorkList;
-    WorkList.emplace_back(Alloc);
-    while (!WorkList.empty()) {
-      AllocaInst *AI = WorkList.front();
-      WorkList.pop_front();
+  while (!WorkList.empty()) {
+    AllocaInst *AI = WorkList.top();
+    WorkList.pop();
+
+    // Handle dead allocas trivially.  These can be formed by SROA'ing arrays
+    // with unused elements.
+    if (AI->use_empty()) {
+      AI->eraseFromParent();
+      Changed = true;
+      continue;
+    }
+    const bool bAllowReplace = true;
+    if (SROA_Helper::LowerMemcpy(AI, /*annotation*/ nullptr, typeSys, DL,
+                                 bAllowReplace)) {
+      Changed = true;
+      continue;
+    }
 
-      // Handle dead allocas trivially.  These can be formed by SROA'ing arrays
-      // with unused elements.
-      if (AI->use_empty()) {
-        AI->eraseFromParent();
-        Changed = true;
-        continue;
-      }
-      const bool bAllowReplace = true;
-      if (SROA_Helper::LowerMemcpy(AI, /*annotation*/ nullptr, typeSys, DL,
-                                   bAllowReplace)) {
-        Changed = true;
-        continue;
-      }
+    // If this alloca is impossible for us to promote, reject it early.
+    if (AI->isArrayAllocation() || !AI->getAllocatedType()->isSized())
+      continue;
 
-      // If this alloca is impossible for us to promote, reject it early.
-      if (AI->isArrayAllocation() || !AI->getAllocatedType()->isSized())
-        continue;
+    // Check to see if we can perform the core SROA transformation.  We cannot
+    // transform the allocation instruction if it is an array allocation
+    // (allocations OF arrays are ok though), and an allocation of a scalar
+    // value cannot be decomposed at all.
+    uint64_t AllocaSize = DL.getTypeAllocSize(AI->getAllocatedType());
 
-      // Check to see if we can perform the core SROA transformation.  We cannot
-      // transform the allocation instruction if it is an array allocation
-      // (allocations OF arrays are ok though), and an allocation of a scalar
-      // value cannot be decomposed at all.
-      uint64_t AllocaSize = DL.getTypeAllocSize(AI->getAllocatedType());
+    // Do not promote [0 x %struct].
+    if (AllocaSize == 0)
+      continue;
 
-      // Do not promote [0 x %struct].
-      if (AllocaSize == 0)
-        continue;
+    Type *Ty = AI->getAllocatedType();
+    // Skip empty struct type.
+    if (SROA_Helper::IsEmptyStructType(Ty, typeSys)) {
+      SROA_Helper::MarkEmptyStructUsers(AI, DeadInsts);
+      DeleteDeadInstructions();
+      continue;
+    }
 
-      Type *Ty = AI->getAllocatedType();
-      // Skip empty struct type.
-      if (SROA_Helper::IsEmptyStructType(Ty, typeSys)) {
-        SROA_Helper::MarkEmptyStructUsers(AI, DeadInsts);
-        DeleteDeadInstructions();
-        continue;
-      }
+    // If the alloca looks like a good candidate for scalar replacement, and
+    // if
+    // all its users can be transformed, then split up the aggregate into its
+    // separate elements.
+    if (ShouldAttemptScalarRepl(AI) && isSafeAllocaToScalarRepl(AI)) {
+      std::vector<Value *> Elts;
+      IRBuilder<> Builder(AI);
+      bool hasPrecise = HLModule::HasPreciseAttributeWithMetadata(AI);
 
-      // If the alloca looks like a good candidate for scalar replacement, and
-      // if
-      // all its users can be transformed, then split up the aggregate into its
-      // separate elements.
-      if (ShouldAttemptScalarRepl(AI) && isSafeAllocaToScalarRepl(AI)) {
-        std::vector<Value *> Elts;
-        IRBuilder<> Builder(AI);
-        bool hasPrecise = HLModule::HasPreciseAttributeWithMetadata(AI);
-
-        bool SROAed = SROA_Helper::DoScalarReplacement(
-            AI, Elts, Builder, /*bFlatVector*/ true, hasPrecise, typeSys,
-            DeadInsts);
-
-        if (SROAed) {
-          Type *Ty = AI->getAllocatedType();
-          // Skip empty struct parameters.
-          if (StructType *ST = dyn_cast<StructType>(Ty)) {
-            if (!HLMatrixLower::IsMatrixType(Ty)) {
-              DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
-              if (SA && SA->IsEmptyStruct()) {
-                for (User *U : AI->users()) {
-                  if (StoreInst *SI = dyn_cast<StoreInst>(U))
-                    DeadInsts.emplace_back(SI);
-                }
-                DeleteDeadInstructions();
-                AI->replaceAllUsesWith(UndefValue::get(AI->getType()));
-                AI->eraseFromParent();
-                continue;
+      bool SROAed = SROA_Helper::DoScalarReplacement(
+          AI, Elts, Builder, /*bFlatVector*/ true, hasPrecise, typeSys, DL,
+          DeadInsts);
+
+      if (SROAed) {
+        Type *Ty = AI->getAllocatedType();
+        // Skip empty struct parameters.
+        if (StructType *ST = dyn_cast<StructType>(Ty)) {
+          if (!HLMatrixLower::IsMatrixType(Ty)) {
+            DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
+            if (SA && SA->IsEmptyStruct()) {
+              for (User *U : AI->users()) {
+                if (StoreInst *SI = dyn_cast<StoreInst>(U))
+                  DeadInsts.emplace_back(SI);
               }
+              DeleteDeadInstructions();
+              AI->replaceAllUsesWith(UndefValue::get(AI->getType()));
+              AI->eraseFromParent();
+              continue;
             }
           }
+        }
 
-          // Push Elts into workList.
-          for (auto iter = Elts.begin(); iter != Elts.end(); iter++)
-            WorkList.emplace_back(cast<AllocaInst>(*iter));
-
-          // Now erase any instructions that were made dead while rewriting the
-          // alloca.
-          DeleteDeadInstructions();
-          ++NumReplaced;
-          AI->eraseFromParent();
-          Changed = true;
-          continue;
+        DbgDeclareInst *DDI = nullptr;
+        unsigned debugOffset = 0;
+        auto iter = DDIMap.find(AI);
+        if (iter != DDIMap.end()) {
+          DDI = iter->second;
+        }
+        // Push Elts into workList.
+        for (auto iter = Elts.begin(); iter != Elts.end(); iter++) {
+          AllocaInst *Elt = cast<AllocaInst>(*iter);
+          WorkList.push(Elt);
+          if (DDI) {
+            Type *Ty = Elt->getAllocatedType();
+            unsigned size = DL.getTypeAllocSize(Ty);
+            DIExpression *DDIExp =
+                DIB.createBitPieceExpression(debugOffset, size);
+            debugOffset += size;
+            DbgDeclareInst *EltDDI = cast<DbgDeclareInst>(DIB.insertDeclare(
+                Elt, DDI->getVariable(), DDIExp, DDI->getDebugLoc(), DDI));
+            DDIMap[Elt] = EltDDI;
+          }
         }
-      }
 
-      // Add debug info.
-      if (DDI != nullptr && AI != Alloc) {
-        Type *Ty = AI->getAllocatedType();
-        unsigned size = DL.getTypeAllocSize(Ty);
-        DIExpression *DDIExp = DIB.createBitPieceExpression(debugOffset, size);
-        debugOffset += size;
-        DIB.insertDeclare(AI, DDI->getVariable(), DDIExp, DDI->getDebugLoc(),
-                          DDI);
+        // Now erase any instructions that were made dead while rewriting the
+        // alloca.
+        DeleteDeadInstructions();
+        ++NumReplaced;
+        AI->eraseFromParent();
+        Changed = true;
+        continue;
       }
     }
   }
@@ -2184,18 +2196,61 @@ static void SimpleCopy(Value *Dest, Value *Src,
   else
     SimpleValCopy(Dest, Src, idxList, Builder);
 }
+
+static Value *CreateMergedGEP(Value *Ptr, SmallVector<Value *, 16> &idxList,
+                              IRBuilder<> &Builder) {
+  if (GEPOperator *GEPPtr = dyn_cast<GEPOperator>(Ptr)) {
+    SmallVector<Value *, 2> IdxList(GEPPtr->idx_begin(), GEPPtr->idx_end());
+    // Skip idxList.begin() because it is already included in GEPPtr's indices.
+    IdxList.append(idxList.begin() + 1, idxList.end());
+    return Builder.CreateInBoundsGEP(GEPPtr->getPointerOperand(), IdxList);
+  } else {
+    return Builder.CreateInBoundsGEP(Ptr, idxList);
+  }
+}
+
+static void EltMemCpy(Type *Ty, Value *Dest, Value *Src,
+                      SmallVector<Value *, 16> &idxList, IRBuilder<> &Builder,
+                      const DataLayout &DL) {
+  Value *DestGEP = CreateMergedGEP(Dest, idxList, Builder);
+  Value *SrcGEP = CreateMergedGEP(Src, idxList, Builder);
+  unsigned size = DL.getTypeAllocSize(Ty);
+  Builder.CreateMemCpy(DestGEP, SrcGEP, size, size);
+}
+
+static bool IsMemCpyTy(Type *Ty, DxilTypeSystem &typeSys) {
+  if (!Ty->isAggregateType())
+    return false;
+  if (HLMatrixLower::IsMatrixType(Ty))
+    return false;
+  if (HLModule::IsHLSLObjectType(Ty))
+    return false;
+  if (StructType *ST = dyn_cast<StructType>(Ty)) {
+    DxilStructAnnotation *STA = typeSys.GetStructAnnotation(ST);
+    DXASSERT(STA, "require annotation here");
+    if (STA->IsEmptyStruct())
+      return false;
+    // Skip single-element structs whose element is a basic type, because
+    // creating a memcpy would create a GEP on the struct and memcpy the basic
+    // type only.
+    if (ST->getNumElements() == 1)
+      return IsMemCpyTy(ST->getElementType(0), typeSys);
+  }
+  return true;
+}
+
 // Split copy into ld/st.
 static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
                      SmallVector<Value *, 16> &idxList, IRBuilder<> &Builder,
-                     DxilTypeSystem &typeSys,
-                     DxilFieldAnnotation *fieldAnnotation) {
+                     const DataLayout &DL, DxilTypeSystem &typeSys,
+                     DxilFieldAnnotation *fieldAnnotation, const bool bEltMemCpy = true) {
   if (PointerType *PT = dyn_cast<PointerType>(Ty)) {
     Constant *idx = Constant::getIntegerValue(
         IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
     idxList.emplace_back(idx);
 
-    SplitCpy(PT->getElementType(), Dest, Src, idxList, Builder, typeSys,
-             fieldAnnotation);
+    SplitCpy(PT->getElementType(), Dest, Src, idxList, Builder, DL, typeSys,
+             fieldAnnotation, bEltMemCpy);
 
     idxList.pop_back();
   } else if (HLMatrixLower::IsMatrixType(Ty)) {
@@ -2246,12 +2301,16 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
       return;
     for (uint32_t i = 0; i < ST->getNumElements(); i++) {
       llvm::Type *ET = ST->getElementType(i);
-
       Constant *idx = llvm::Constant::getIntegerValue(
           IntegerType::get(Ty->getContext(), 32), APInt(32, i));
       idxList.emplace_back(idx);
-      DxilFieldAnnotation &EltAnnotation = STA->GetFieldAnnotation(i);
-      SplitCpy(ET, Dest, Src, idxList, Builder, typeSys, &EltAnnotation);
+      if (bEltMemCpy && IsMemCpyTy(ET, typeSys)) {
+        EltMemCpy(ET, Dest, Src, idxList, Builder, DL);
+      } else {
+        DxilFieldAnnotation &EltAnnotation = STA->GetFieldAnnotation(i);
+        SplitCpy(ET, Dest, Src, idxList, Builder, DL, typeSys, &EltAnnotation,
+                 bEltMemCpy);
+      }
 
       idxList.pop_back();
     }
@@ -2263,7 +2322,12 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
       Constant *idx = Constant::getIntegerValue(
           IntegerType::get(Ty->getContext(), 32), APInt(32, i));
       idxList.emplace_back(idx);
-      SplitCpy(ET, Dest, Src, idxList, Builder, typeSys, fieldAnnotation);
+      if (bEltMemCpy && IsMemCpyTy(ET, typeSys)) {
+        EltMemCpy(ET, Dest, Src, idxList, Builder, DL);
+      } else {
+        SplitCpy(ET, Dest, Src, idxList, Builder, DL, typeSys, fieldAnnotation,
+                 bEltMemCpy);
+      }
 
       idxList.pop_back();
     }
@@ -2372,8 +2436,16 @@ static unsigned MatchSizeByCheckElementType(Type *Ty, const DataLayout &DL, unsi
 static void PatchZeroIdxGEP(Value *Ptr, Value *RawPtr, MemCpyInst *MI,
                             unsigned level, IRBuilder<> &Builder) {
   Value *zeroIdx = Builder.getInt32(0);
-  SmallVector<Value *, 2> IdxList(level + 1, zeroIdx);
-  Value *GEP = Builder.CreateInBoundsGEP(Ptr, IdxList);
+  Value *GEP = nullptr;
+  if (GEPOperator *GEPPtr = dyn_cast<GEPOperator>(Ptr)) {
+    SmallVector<Value *, 2> IdxList(GEPPtr->idx_begin(), GEPPtr->idx_end());
+    // level not + 1 because it is included in GEPPtr idx.
+    IdxList.append(level, zeroIdx);
+    GEP = Builder.CreateInBoundsGEP(GEPPtr->getPointerOperand(), IdxList);
+  } else {
+    SmallVector<Value *, 2> IdxList(level + 1, zeroIdx);
+    GEP = Builder.CreateInBoundsGEP(Ptr, IdxList);
+  }
   // Use BitCastInst::Create to prevent idxList from being optimized.
   CastInst *Cast =
       BitCastInst::Create(Instruction::BitCast, GEP, RawPtr->getType());
@@ -2461,7 +2533,7 @@ static void DeleteMemcpy(MemCpyInst *MI) {
 
 void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
                                  DxilFieldAnnotation *fieldAnnotation,
-                                 DxilTypeSystem &typeSys) {
+                                 DxilTypeSystem &typeSys, const bool bEltMemCpy) {
   Value *Dest = MI->getRawDest();
   Value *Src = MI->getRawSource();
   // Only remove one level bitcast generated from inline.
@@ -2489,28 +2561,34 @@ void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
   // split
   // Matrix is treated as scalar type, will not use memcpy.
   // So use nullptr for fieldAnnotation should be safe here.
-  SplitCpy(Dest->getType(), Dest, Src, idxList, Builder, typeSys,
-           fieldAnnotation);
+  SplitCpy(Dest->getType(), Dest, Src, idxList, Builder, DL, typeSys,
+           fieldAnnotation, bEltMemCpy);
   // delete memcpy
   DeleteMemcpy(MI);
 }
 
 void MemcpySplitter::Split(llvm::Function &F) {
   const DataLayout &DL = F.getParent()->getDataLayout();
-  // Walk all instruction in the function.
-  for (Function::iterator BB = F.begin(), BBE = F.end(); BB != BBE; ++BB) {
-    for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) {
-      // Avoid invalidating the iterator.
-      Instruction *I = BI++;
-
-      if (MemCpyInst *MI = dyn_cast<MemCpyInst>(I)) {
-        // Matrix is treated as scalar type, will not use memcpy.
-        // So use nullptr for fieldAnnotation should be safe here.
-        SplitMemCpy(MI, DL, /*fieldAnnotation*/ nullptr, m_typeSys);
-      }
+
+  Function *memcpy = nullptr;
+  for (Function &Fn : F.getParent()->functions()) {
+    if (Fn.getIntrinsicID() == Intrinsic::memcpy) {
+      memcpy = &Fn;
+      break;
     }
   }
-}
+  if (memcpy) {
+    for (auto U = memcpy->user_begin(); U != memcpy->user_end();) {
+      MemCpyInst *MI = cast<MemCpyInst>(*(U++));
+      if (MI->getParent()->getParent() != &F)
+        continue;
+      // Matrix is treated as scalar type, will not use memcpy.
+      // So use nullptr for fieldAnnotation should be safe here.
+      SplitMemCpy(MI, DL, /*fieldAnnotation*/ nullptr, m_typeSys,
+                  /*bEltMemCpy*/ false);
+    }
+  }
+ }
 
 //===----------------------------------------------------------------------===//
 // SRoA Helper
@@ -2583,7 +2661,14 @@ void SROA_Helper::RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder) {
         Value *NewGEP = Builder.CreateGEP(nullptr, NewElts[i], NewArgs);
         NewGEPs.emplace_back(NewGEP);
       }
-      SROA_Helper helper(GEP, NewGEPs, DeadInsts);
+      const bool bAllowReplace = isa<AllocaInst>(OldVal);
+      if (SROA_Helper::LowerMemcpy(GEP, /*annotation*/ nullptr, typeSys, DL,
+                                   bAllowReplace)) {
+        if (GEP->user_empty() && isa<Instruction>(GEP))
+          DeadInsts.push_back(GEP);
+        return;
+      }
+      SROA_Helper helper(GEP, NewGEPs, DeadInsts, typeSys, DL);
       helper.RewriteForScalarRepl(GEP, Builder);
       for (Value *NewGEP : NewGEPs) {
         if (NewGEP->user_empty() && isa<Instruction>(NewGEP)) {
@@ -3111,7 +3196,7 @@ void SROA_Helper::RewriteForAddrSpaceCast(ConstantExpr *CE,
                          CE->getType()->getPointerAddressSpace()));
     NewCasts.emplace_back(NewGEP);
   }
-  SROA_Helper helper(CE, NewCasts, DeadInsts);
+  SROA_Helper helper(CE, NewCasts, DeadInsts, typeSys, DL);
   helper.RewriteForScalarRepl(CE, Builder);
 }
 
@@ -3195,6 +3280,7 @@ static ArrayType *CreateNestArrayTy(Type *FinalEltTy,
 bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
                                       IRBuilder<> &Builder, bool bFlatVector,
                                       bool hasPrecise, DxilTypeSystem &typeSys,
+                                      const DataLayout &DL,
                                       SmallVector<Value *, 32> &DeadInsts) {
   DEBUG(dbgs() << "Found inst to SROA: " << *V << '\n');
   Type *Ty = V->getType();
@@ -3315,7 +3401,7 @@ bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
   
   // Now that we have created the new alloca instructions, rewrite all the
   // uses of the old alloca.
-  SROA_Helper helper(V, Elts, DeadInsts);
+  SROA_Helper helper(V, Elts, DeadInsts, typeSys, DL);
   helper.RewriteForScalarRepl(V, Builder);
 
   return true;
@@ -3361,9 +3447,11 @@ static Constant *GetEltInit(Type *Ty, Constant *Init, unsigned idx,
 
 /// DoScalarReplacement - Split V into AllocaInsts with Builder and save the new AllocaInsts into Elts.
 /// Then do SROA on V.
-bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &Elts,
+bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
+                                      std::vector<Value *> &Elts,
                                       IRBuilder<> &Builder, bool bFlatVector,
                                       bool hasPrecise, DxilTypeSystem &typeSys,
+                                      const DataLayout &DL,
                                       SmallVector<Value *, 32> &DeadInsts) {
   DEBUG(dbgs() << "Found inst to SROA: " << *GV << '\n');
   Type *Ty = GV->getType();
@@ -3503,7 +3591,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &
 
   // Now that we have created the new alloca instructions, rewrite all the
   // uses of the old alloca.
-  SROA_Helper helper(GV, Elts, DeadInsts);
+  SROA_Helper helper(GV, Elts, DeadInsts, typeSys, DL);
   helper.RewriteForScalarRepl(GV, Builder);
 
   return true;
@@ -3582,12 +3670,6 @@ struct PointerStatus {
 
 void PointerStatus::analyzePointer(const Value *V, PointerStatus &PS,
                                    DxilTypeSystem &typeSys, bool bStructElt) {
-  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
-    if (GV->hasInitializer() && !isa<UndefValue>(GV->getInitializer())) {
-      PS.StoredType = PointerStatus::StoredType::InitializerStored;
-    }
-  }
-
   for (const User *U : V->users()) {
     if (const Instruction *I = dyn_cast<Instruction>(U)) {
       const Function *F = I->getParent()->getParent();
@@ -3796,6 +3878,90 @@ static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC) {
   }
 }
 
+static bool ReplaceUseOfZeroInitEntry(Instruction *I, Value *V) {
+  BasicBlock *BB = I->getParent();
+  Function *F = I->getParent()->getParent();
+  for (auto U = V->user_begin(); U != V->user_end(); ) {
+    Instruction *UI = dyn_cast<Instruction>(*(U++));
+    if (!UI)
+      continue;
+
+    if (UI->getParent()->getParent() != F)
+      continue;
+
+    if (isa<GetElementPtrInst>(UI) || isa<BitCastInst>(UI)) {
+      if (!ReplaceUseOfZeroInitEntry(I, UI))
+        return false;
+      else
+        continue;
+    }
+    if (BB != UI->getParent() || UI == I)
+      continue;
+    // I is the last inst in the block after split.
+    // Any inst in current block is before I.
+    if (LoadInst *LI = dyn_cast<LoadInst>(UI)) {
+      LI->replaceAllUsesWith(ConstantAggregateZero::get(LI->getType()));
+      LI->eraseFromParent();
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
+static bool ReplaceUseOfZeroInitPostDom(Instruction *I, Value *V,
+                                    PostDominatorTree &PDT) {
+  BasicBlock *BB = I->getParent();
+  Function *F = I->getParent()->getParent();
+  for (auto U = V->user_begin(); U != V->user_end(); ) {
+    Instruction *UI = dyn_cast<Instruction>(*(U++));
+    if (!UI)
+      continue;
+    if (UI->getParent()->getParent() != F)
+      continue;
+
+    if (!PDT.dominates(BB, UI->getParent()))
+      return false;
+
+    if (isa<GetElementPtrInst>(UI) || isa<BitCastInst>(UI)) {
+      if (!ReplaceUseOfZeroInitPostDom(I, UI, PDT))
+        return false;
+      else
+        continue;
+    }
+
+    if (BB != UI->getParent() || UI == I)
+      continue;
+    // I is the last inst in the block after split.
+    // Any inst in current block is before I.
+    if (LoadInst *LI = dyn_cast<LoadInst>(UI)) {
+      LI->replaceAllUsesWith(ConstantAggregateZero::get(LI->getType()));
+      LI->eraseFromParent();
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+// When zero initialized GV has only one define, all uses before the def should
+// use zero.
+static bool ReplaceUseOfZeroInitBeforeDef(Instruction *I, GlobalVariable *GV) {
+  BasicBlock *BB = I->getParent();
+  Function *F = I->getParent()->getParent();
+  // Make sure I is the last inst for BB.
+  if (I != BB->getTerminator())
+    BB->splitBasicBlock(I->getNextNode());
+
+  if (&F->getEntryBlock() == I->getParent()) {
+    return ReplaceUseOfZeroInitEntry(I, GV);
+  } else {
+    // Post dominator tree.
+    PostDominatorTree PDT;
+    PDT.runOnFunction(*F);
+    return ReplaceUseOfZeroInitPostDom(I, GV, PDT);
+  }
+}
+
 bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
                               DxilTypeSystem &typeSys, const DataLayout &DL,
                               bool bAllowReplace) {
@@ -3810,6 +3976,32 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
   PointerStatus PS(size);
   const bool bStructElt = false;
   PointerStatus::analyzePointer(V, PS, typeSys, bStructElt);
+
+  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
+    if (GV->hasInitializer() && !isa<UndefValue>(GV->getInitializer())) {
+      if (PS.StoredType == PointerStatus::StoredType::NotStored) {
+        PS.StoredType = PointerStatus::StoredType::InitializerStored;
+      } else if (PS.StoredType == PointerStatus::StoredType::MemcopyDestOnce) {
+        // For single mem store, if the store not dominator all users.
+        // Makr it as Stored.
+        // Case like:
+        // struct A { float4 x[25]; };
+        // A a;
+        // static A a2;
+        // void set(A aa) { aa = a; }
+        // call set inside entry function then use a2.
+        if (isa<ConstantAggregateZero>(GV->getInitializer())) {
+          Instruction * Memcpy = PS.StoringMemcpy;
+          if (!ReplaceUseOfZeroInitBeforeDef(Memcpy, GV)) {
+            PS.StoredType = PointerStatus::StoredType::Stored;
+          }
+        }
+      } else {
+        PS.StoredType = PointerStatus::StoredType::Stored;
+      }
+    }
+  }
+
   if (bAllowReplace && !PS.HasMultipleAccessingFunctions) {
     if (PS.StoredType == PointerStatus::StoredType::MemcopyDestOnce &&
         // Skip argument for input argument has input value, it is not dest once anymore.
@@ -4199,8 +4391,7 @@ void SROA_Parameter_HLSL::flattenGlobal(GlobalVariable *GV) {
     bool SROAed = SROA_Helper::DoScalarReplacement(
         EltGV, Elts, Builder, bFlatVector,
         // TODO: set precise.
-        /*hasPrecise*/ false,
-        dxilTypeSys, DeadInsts);
+        /*hasPrecise*/ false, dxilTypeSys, DL, DeadInsts);
 
     if (SROAed) {
       // Push Elts into workList.
@@ -5176,7 +5367,7 @@ void SROA_Parameter_HLSL::flattenArgument(
     // Not flat vector for entry function currently.
     bool SROAed = SROA_Helper::DoScalarReplacement(
         V, Elts, Builder, /*bFlatVector*/ false, annotation.IsPrecise(),
-        dxilTypeSys, DeadInsts);
+        dxilTypeSys, DL, DeadInsts);
 
     if (SROAed) {
       Type *Ty = V->getType()->getPointerElementType();
@@ -5378,7 +5569,7 @@ void SROA_Parameter_HLSL::flattenArgument(
                 IRBuilder<> Builder(CI);
 
                 llvm::SmallVector<llvm::Value *, 16> idxList;
-                SplitCpy(data->getType(), outputVal, data, idxList, Builder,
+                SplitCpy(data->getType(), outputVal, data, idxList, Builder, DL,
                          dxilTypeSys, &flatParamAnnotation);
 
                 CI->setArgOperand(HLOperandIndex::kStreamAppendDataOpIndex, outputVal);
@@ -5405,7 +5596,7 @@ void SROA_Parameter_HLSL::flattenArgument(
 
                   llvm::SmallVector<llvm::Value *, 16> idxList;
                   SplitCpy(DataPtr->getType(), EltPtr, DataPtr, idxList,
-                           Builder, dxilTypeSys, &flatParamAnnotation);
+                           Builder, DL, dxilTypeSys, &flatParamAnnotation);
                   CI->setArgOperand(i, EltPtr);
                 }
               }
@@ -5562,7 +5753,8 @@ void SROA_Parameter_HLSL::moveFunctionBody(Function *F, Function *flatF) {
   }
 }
 
-static void SplitArrayCopy(Value *V, DxilTypeSystem &typeSys,
+static void SplitArrayCopy(Value *V, const DataLayout &DL,
+                           DxilTypeSystem &typeSys,
                            DxilFieldAnnotation *fieldAnnotation) {
   for (auto U = V->user_begin(); U != V->user_end();) {
     User *user = *(U++);
@@ -5571,7 +5763,7 @@ static void SplitArrayCopy(Value *V, DxilTypeSystem &typeSys,
       Value *val = ST->getValueOperand();
       IRBuilder<> Builder(ST);
       SmallVector<Value *, 16> idxList;
-      SplitCpy(ptr->getType(), ptr, val, idxList, Builder, typeSys,
+      SplitCpy(ptr->getType(), ptr, val, idxList, Builder, DL, typeSys,
                fieldAnnotation);
       ST->eraseFromParent();
     }
@@ -5614,6 +5806,7 @@ static void CheckArgUsage(Value *V, bool &bLoad, bool &bStore) {
 // Support store to input and load from output.
 static void LegalizeDxilInputOutputs(Function *F,
                                      DxilFunctionAnnotation *EntryAnnotation,
+                                     const DataLayout &DL,
                                      DxilTypeSystem &typeSys) {
   BasicBlock &EntryBlk = F->getEntryBlock();
   Module *M = F->getParent();
@@ -5710,7 +5903,7 @@ static void LegalizeDxilInputOutputs(Function *F,
       if (bStoreInputToTemp) {
         llvm::SmallVector<llvm::Value *, 16> idxList;
         // split copy.
-        SplitCpy(temp->getType(), temp, &arg, idxList, Builder, typeSys,
+        SplitCpy(temp->getType(), temp, &arg, idxList, Builder, DL, typeSys,
                  &paramAnnotation);
       }
 
@@ -5740,7 +5933,7 @@ static void LegalizeDxilInputOutputs(Function *F,
         else
           onlyRetBlk = true;
         // split copy.
-        SplitCpy(output->getType(), output, temp, idxList, Builder, typeSys,
+        SplitCpy(output->getType(), output, temp, idxList, Builder, DL, typeSys,
                  &paramAnnotation);
       }
       // Clone the return.
@@ -5752,7 +5945,7 @@ static void LegalizeDxilInputOutputs(Function *F,
 
 void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
   DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
-
+  const DataLayout &DL = m_pHLModule->GetModule()->getDataLayout();
   // Skip void (void) function.
   if (F->getReturnType()->isVoidTy() && F->getArgumentList().empty()) {
     return;
@@ -5933,7 +6126,7 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
     }
     if (!F->isDeclaration()) {
       // Support store to input and load from output.
-      LegalizeDxilInputOutputs(F, funcAnnotation, typeSys);
+      LegalizeDxilInputOutputs(F, funcAnnotation, DL, typeSys);
     }
     return;
   }
@@ -6074,12 +6267,12 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
         Type *Ty = Arg->getType()->getPointerElementType();
         if (Ty->isArrayTy())
           SplitArrayCopy(
-              Arg, typeSys,
+              Arg, DL, typeSys,
               &flatFuncAnnotation->GetParameterAnnotation(Arg->getArgNo()));
       }
     }
     // Support store to input and load from output.
-    LegalizeDxilInputOutputs(flatF, flatFuncAnnotation, typeSys);
+    LegalizeDxilInputOutputs(flatF, flatFuncAnnotation, DL, typeSys);
   }
 }
 
@@ -6092,6 +6285,7 @@ void SROA_Parameter_HLSL::createFlattenedFunctionCall(Function *F, Function *fla
   vectorEltsMap.clear();
 
   DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
+  const DataLayout &DL = m_pHLModule->GetModule()->getDataLayout();
 
   std::vector<Value *> FlatParamList;
   std::vector<DxilParameterAnnotation> FlatParamAnnotationList;
@@ -6179,7 +6373,7 @@ void SROA_Parameter_HLSL::createFlattenedFunctionCall(Function *F, Function *fla
         // Copy in param.
         llvm::SmallVector<llvm::Value *, 16> idxList;
         // split copy to avoid load of struct.
-        SplitCpy(Ty, tempArg, arg, idxList, CallBuilder, typeSys,
+        SplitCpy(Ty, tempArg, arg, idxList, CallBuilder, DL, typeSys,
                  &paramAnnotation);
       }
 
@@ -6188,7 +6382,7 @@ void SROA_Parameter_HLSL::createFlattenedFunctionCall(Function *F, Function *fla
         // Copy out param.
         llvm::SmallVector<llvm::Value *, 16> idxList;
         // split copy to avoid load of struct.
-        SplitCpy(Ty, arg, tempArg, idxList, RetBuilder, typeSys,
+        SplitCpy(Ty, arg, tempArg, idxList, RetBuilder, DL, typeSys,
                  &paramAnnotation);
       }
       arg = tempArg;
@@ -7055,7 +7249,8 @@ void ResourceToHandle::ReplaceResourceWithHandle(Value *ResPtr,
       // Remove resource Store.
       SI->eraseFromParent();
     } else {
-      DXASSERT(0, "invalid operation on resource");
+      if (!U->user_empty() || !isa<GEPOperator>(U))
+        DXASSERT(0, "invalid operation on resource");
     }
   }
 }

+ 4 - 1
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -2921,6 +2921,9 @@ static bool CreateCBufferVariable(HLCBuffer &CB,
     if (cbSubscript->user_empty()) {
       cbSubscript->eraseFromParent();
       Handle->eraseFromParent();
+    } else {
+      // merge GEP use for cbSubscript.
+      HLModule::MergeGepUse(cbSubscript);
     }
   }
   return true;
@@ -4299,7 +4302,7 @@ void CGMSHLSLRuntime::FinishCodeGen() {
     if (f.hasFnAttribute(llvm::Attribute::NoInline))
       continue;
     // Always inline for used functions.
-    if (!f.user_empty())
+    if (!f.user_empty() && !f.isDeclaration())
       f.addFnAttr(llvm::Attribute::AlwaysInline);
   }
 

+ 23 - 0
tools/clang/test/CodeGenHLSL/quick-test/static_global_copy.hlsl

@@ -0,0 +1,23 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+
+
+// Make sure a static global initialized inside a user function can still be propagated.
+// CHECK-NOT: alloca
+
+struct A {
+  float4 x[25];
+};
+
+A a;
+
+static A a2;
+
+void set(A aa) {
+   aa = a;
+}
+
+float4 main(uint l:L) : SV_Target {
+  set(a2);
+  return a2.x[l];
+}

+ 30 - 0
tools/clang/test/CodeGenHLSL/quick-test/static_global_copy2.hlsl

@@ -0,0 +1,30 @@
+// RUN: %dxc -E main -T ps_6_0 -Zi %s | FileCheck %s
+
+
+// Make sure debug info works for flattened alloca.
+// CHECK:call void @llvm.dbg.declare(metadata [2 x float]* %a2.1, 
+
+struct X {
+   float a;
+   int b;
+};
+
+struct A {
+  X x[25];
+  float y[2];
+};
+
+A a;
+float b;
+
+void set(A aa) {
+   aa = a;
+   aa.y[0] = b;
+   aa.y[1] = 3;
+}
+
+float4 main(uint l:L) : SV_Target {
+  A a2;
+  set(a2);
+  return a2.x[l].a + a2.y[l];
+}

+ 28 - 0
tools/clang/test/CodeGenHLSL/quick-test/static_global_copy3.hlsl

@@ -0,0 +1,28 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Make sure a static global initialized inside a user function can still be propagated.
+// CHECK-NOT: alloca
+
+// Make sure cbuffer is used.
+// CHECK: call %dx.types.CBufRet.f32 @dx.op.cbufferLoad
+
+// Make sure uses of the zero initializer get zero.
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float 0.000000e+00)
+
+struct A {
+  float4 x[25];
+};
+
+A a;
+
+static A a2;
+
+void set(A aa) {
+   aa = a;
+}
+
+float2 main(uint l:L) : SV_Target {
+  float m = a2.x[l].x;
+  set(a2);
+  return float2(a2.x[l].x,m);
+}