فهرست منبع

Remove copy by replacing dst with src. (#296)

1. Generate a memcpy in EmitHLSLFlatConversionAggregateCopy if the types match.
2. If a value is used as the Src of a memcpy, mark it as loaded instead of stored.
3. Do not do the replacement on a GEP or bitcast.
4. Delete the inst instead of pushing it into DeadInsts when safe.
     This prevents the same inst from being added to DeadInsts more than once.
Xiang Li 8 سال پیش
والد
کامیت
6cca464f51

+ 1 - 1
include/dxc/HLSL/HLOperations.h

@@ -118,7 +118,7 @@ HLOpcodeGroup GetHLOpcodeGroup(llvm::Function *F);
 HLOpcodeGroup GetHLOpcodeGroupByName(const llvm::Function *F);
 llvm::StringRef GetHLOpcodeGroupNameByAttr(llvm::Function *F);
 llvm::StringRef GetHLLowerStrategy(llvm::Function *F);
-unsigned  GetHLOpcode(llvm::CallInst *CI);
+unsigned  GetHLOpcode(const llvm::CallInst *CI);
 unsigned  GetRowMajorOpcode(HLOpcodeGroup group, unsigned opcode);
 void SetHLLowerStrategy(llvm::Function *F, llvm::StringRef S);
 

+ 1 - 1
lib/HLSL/HLOperations.cpp

@@ -313,7 +313,7 @@ std::string GetHLFullName(HLOpcodeGroup op, unsigned opcode) {
 }
 
 // Get opcode from arg0 of function call.
-unsigned  GetHLOpcode(CallInst *CI) {
+unsigned  GetHLOpcode(const CallInst *CI) {
   Value *idArg = CI->getArgOperand(HLOperandIndex::kOpcodeIdx);
   Constant *idConst = cast<Constant>(idArg);
   return idConst->getUniqueInteger().getLimitedValue();

+ 165 - 26
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -2379,17 +2379,33 @@ void MemcpySplitter::PatchMemCpyWithZeroIdxGEP(MemCpyInst *MI,
     Src = BC->getOperand(0);
 
   IRBuilder<> Builder(MI);
+  ConstantInt *zero = Builder.getInt32(0);
   Type *DestTy = Dest->getType()->getPointerElementType();
   Type *SrcTy = Src->getType()->getPointerElementType();
   // Support case when bitcast (gep ptr, 0,0) is transformed into
   // bitcast ptr.
+  // Also replace (gep ptr, 0) with ptr.
   ConstantInt *Length = cast<ConstantInt>(MI->getLength());
   unsigned size = Length->getLimitedValue();
   if (unsigned level = MatchSizeByCheckElementType(DestTy, DL, size, 0)) {
     PatchZeroIdxGEP(Dest, MI->getRawDest(), MI, level, Builder);
+  } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(Dest)) {
+    if (GEP->getNumIndices() == 1) {
+       Value *idx = *GEP->idx_begin();
+       if (idx == zero) {
+         GEP->replaceAllUsesWith(GEP->getPointerOperand());
+       }
+    }
   }
   if (unsigned level = MatchSizeByCheckElementType(SrcTy, DL, size, 0)) {
     PatchZeroIdxGEP(Src, MI->getRawSource(), MI, level, Builder);
+  } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) {
+    if (GEP->getNumIndices() == 1) {
+      Value *idx = *GEP->idx_begin();
+      if (idx == zero) {
+        GEP->replaceAllUsesWith(GEP->getPointerOperand());
+      }
+    }
   }
 }
 
@@ -2541,9 +2557,11 @@ void SROA_Helper::RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder) {
       for (Value *NewGEP : NewGEPs) {
         if (NewGEP->user_empty() && isa<Instruction>(NewGEP)) {
           // Delete unused newGEP.
-          DeadInsts.emplace_back(NewGEP);
+          cast<Instruction>(NewGEP)->eraseFromParent();
         }
       }
+      if (GEP->user_empty() && isa<Instruction>(GEP))
+        DeadInsts.push_back(GEP);
     } else {
       Value *vecIdx = NewArgs.back();
       if (ConstantInt *immVecIdx = dyn_cast<ConstantInt>(vecIdx)) {
@@ -3429,7 +3447,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &
 struct PointerStatus {
   /// Keep track of what stores to the pointer look like.
   enum StoredType {
-    /// There is no store to this global.  It can thus be marked constant.
+    /// There is no store to this pointer.  It can thus be marked constant.
     NotStored,
 
     /// This ptr is a global, and is stored to, but the only thing stored is the
@@ -3450,7 +3468,18 @@ struct PointerStatus {
     /// cannot track.
     Stored
   } StoredType;
+  /// Keep track of what loaded from the pointer look like.
+  enum LoadedType {
+    /// There is no load to this pointer.  It can thus be marked constant.
+    NotLoaded,
 
+    /// This ptr is only used by a memcpy.
+    MemcopySrcOnce,
+
+    /// This ptr is loaded to by multiple instructions or something else that we
+    /// cannot track.
+    Loaded
+  } LoadedType;
   /// If only one value (besides the initializer constant) is ever stored to
   /// this global, keep track of what value it is.
   Value *StoredOnceValue;
@@ -3458,6 +3487,8 @@ struct PointerStatus {
   std::vector<MemCpyInst *> memcpyList;
   /// Memcpy which use this ptr as dest.
   MemCpyInst *StoringMemcpy;
+  /// Memcpy which use this ptr as src.
+  MemCpyInst *LoadingMemcpy;
   /// These start out null/false.  When the first accessing function is noticed,
   /// it is recorded. When a second different accessing function is noticed,
   /// HasMultipleAccessingFunctions is set to true.
@@ -3473,13 +3504,15 @@ struct PointerStatus {
                              DxilTypeSystem &typeSys, bool bStructElt);
 
   PointerStatus(unsigned size)
-      : StoredType(NotStored), StoredOnceValue(nullptr), StoringMemcpy(nullptr),
+      : StoredType(NotStored), LoadedType(NotLoaded), StoredOnceValue(nullptr),
+        StoringMemcpy(nullptr), LoadingMemcpy(nullptr),
         AccessingFunction(nullptr), HasMultipleAccessingFunctions(false),
         Size(size) {}
   void MarkAsStored() {
     StoredType = PointerStatus::StoredType::Stored;
     StoredOnceValue = nullptr;
   }
+  void MarkAsLoaded() { LoadedType = PointerStatus::LoadedType::Loaded; }
 };
 
 void PointerStatus::analyzePointer(const Value *V, PointerStatus &PS,
@@ -3521,9 +3554,23 @@ void PointerStatus::analyzePointer(const Value *V, PointerStatus &PS,
             PS.MarkAsStored();
             PS.StoringMemcpy = nullptr;
           }
+        } else if (MC->getRawSource() == V) {
+          if (bFullCopy &&
+              PS.LoadedType == PointerStatus::LoadedType::NotLoaded) {
+            PS.LoadedType = PointerStatus::LoadedType::MemcopySrcOnce;
+            PS.LoadingMemcpy = PS.memcpyList.back();
+          } else {
+            PS.MarkAsLoaded();
+            PS.LoadingMemcpy = nullptr;
+          }
         }
       } else {
-        PS.MarkAsStored();
+        if (MC->getRawDest() == V) {
+          PS.MarkAsStored();
+        } else {
+          DXASSERT(MC->getRawSource() == V, "must be source here");
+          PS.MarkAsLoaded();
+        }
       }
     } else if (const GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
       gep_type_iterator GEPIt = gep_type_begin(GEP);
@@ -3542,12 +3589,56 @@ void PointerStatus::analyzePointer(const Value *V, PointerStatus &PS,
       } else {
         PS.MarkAsStored();
       }
+    } else if (const LoadInst *LI = dyn_cast<LoadInst>(U)) {
+      PS.MarkAsLoaded();
     } else if (const CallInst *CI = dyn_cast<CallInst>(U)) {
       Function *F = CI->getCalledFunction();
       DxilFunctionAnnotation *annotation = typeSys.GetFunctionAnnotation(F);
       if (!annotation) {
-        // If not sure its out param or not. Take as out param.
-        PS.MarkAsStored();
+        HLOpcodeGroup group = hlsl::GetHLOpcodeGroupByName(F);
+        switch (group) {
+        case HLOpcodeGroup::HLMatLoadStore: {
+          HLMatLoadStoreOpcode opcode =
+              static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
+          switch (opcode) {
+          case HLMatLoadStoreOpcode::ColMatLoad:
+          case HLMatLoadStoreOpcode::RowMatLoad:
+            PS.MarkAsLoaded();
+            break;
+          case HLMatLoadStoreOpcode::ColMatStore:
+          case HLMatLoadStoreOpcode::RowMatStore:
+            PS.MarkAsStored();
+            break;
+          default:
+            DXASSERT(0, "invalid opcode");
+            PS.MarkAsStored();
+            PS.MarkAsLoaded();
+          }
+        } break;
+        case HLOpcodeGroup::HLSubscript: {
+          HLSubscriptOpcode opcode =
+              static_cast<HLSubscriptOpcode>(hlsl::GetHLOpcode(CI));
+          switch (opcode) {
+          case HLSubscriptOpcode::VectorSubscript:
+          case HLSubscriptOpcode::ColMatElement:
+          case HLSubscriptOpcode::ColMatSubscript:
+          case HLSubscriptOpcode::RowMatElement:
+          case HLSubscriptOpcode::RowMatSubscript:
+            analyzePointer(CI, PS, typeSys, bStructElt);
+            break;
+          default:
+            // Rest are resource ptr like buf[i].
+            // Only read of resource handle.
+            PS.MarkAsLoaded();
+            break;
+          }
+        } break;
+        default: {
+          // If not sure its out param or not. Take as out param.
+          PS.MarkAsStored();
+          PS.MarkAsLoaded();
+        }
+        }
         continue;
       }
 
@@ -3559,6 +3650,11 @@ void PointerStatus::analyzePointer(const Value *V, PointerStatus &PS,
               annotation->GetParameterAnnotation(i).GetParamInputQual();
           if (inputQual != DxilParamInputQual::In) {
             PS.MarkAsStored();
+            if (inputQual == DxilParamInputQual::Inout)
+              PS.MarkAsLoaded();
+            break;
+          } else {
+            PS.MarkAsLoaded();
             break;
           }
         }
@@ -3621,26 +3717,69 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
   PointerStatus PS(size);
   const bool bStructElt = false;
   PointerStatus::analyzePointer(V, PS, typeSys, bStructElt);
-  if (bAllowReplace &&
-      PS.StoredType == PointerStatus::StoredType::MemcopyDestOnce &&
-      !PS.HasMultipleAccessingFunctions) {
-    // How to make sure Src is not updated after Memcopy?
-
-    // Replace with src of memcpy.
-    MemCpyInst *MC = PS.StoringMemcpy;
-    if (MC->getSourceAddressSpace() == MC->getDestAddressSpace()) {
-      Value *Src = MC->getOperand(1);
-      // Only remove one level bitcast generated from inline.
-      if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
-        Src = BC->getOperand(0);
-
-      // Need to make sure src not updated after current memcpy.
-      // Check Src only have 1 store now.
-      PointerStatus SrcPS(size);
-      PointerStatus::analyzePointer(Src, SrcPS, typeSys, bStructElt);
-      if (SrcPS.StoredType != PointerStatus::StoredType::Stored) {
-        ReplaceMemcpy(V, Src, MC);
-        return true;
+  if (bAllowReplace && !PS.HasMultipleAccessingFunctions) {
+    if (PS.StoredType == PointerStatus::StoredType::MemcopyDestOnce) {
+      // Replace with src of memcpy.
+      MemCpyInst *MC = PS.StoringMemcpy;
+      if (MC->getSourceAddressSpace() == MC->getDestAddressSpace()) {
+        Value *Src = MC->getOperand(1);
+        // Only remove one level bitcast generated from inline.
+        if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
+          Src = BC->getOperand(0);
+
+        if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) {
+          // For GEP, the ptr could have other GEP read/write.
+          // Only scan one GEP is not enough.
+          Value *Ptr = GEP->getPointerOperand();
+          if (CallInst *PtrCI = dyn_cast<CallInst>(Ptr)) {
+            hlsl::HLOpcodeGroup group =
+                hlsl::GetHLOpcodeGroup(PtrCI->getCalledFunction());
+            if (group == HLOpcodeGroup::HLSubscript) {
+              HLSubscriptOpcode opcode =
+                  static_cast<HLSubscriptOpcode>(hlsl::GetHLOpcode(PtrCI));
+              if (opcode == HLSubscriptOpcode::CBufferSubscript) {
+                // Ptr from CBuffer is safe.
+                ReplaceMemcpy(V, Src, MC);
+                return true;
+              }
+            }
+          }
+        } else if (!isa<CallInst>(Src)) {
+          // Resource ptr should not be replaced.
+          // Need to make sure src not updated after current memcpy.
+          // Check Src only have 1 store now.
+          PointerStatus SrcPS(size);
+          PointerStatus::analyzePointer(Src, SrcPS, typeSys, bStructElt);
+          if (SrcPS.StoredType != PointerStatus::StoredType::Stored) {
+            ReplaceMemcpy(V, Src, MC);
+            return true;
+          }
+        }
+      }
+    } else if (PS.LoadedType == PointerStatus::LoadedType::MemcopySrcOnce) {
+      // Replace dst of memcpy.
+      MemCpyInst *MC = PS.LoadingMemcpy;
+      if (MC->getSourceAddressSpace() == MC->getDestAddressSpace()) {
+        Value *Dest = MC->getOperand(0);
+        // Only remove one level bitcast generated from inline.
+        if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Dest))
+          Dest = BC->getOperand(0);
+        // For GEP, the ptr could have other GEP read/write.
+        // Only scan one GEP is not enough.
+        // And resource ptr should not be replaced.
+        if (!isa<GEPOperator>(Dest) && !isa<CallInst>(Dest) &&
+            !isa<BitCastOperator>(Dest)) {
+          // Need to make sure Dest not updated after current memcpy.
+          // Check Dest only have 1 store now.
+          PointerStatus DestPS(size);
+          PointerStatus::analyzePointer(Dest, DestPS, typeSys, bStructElt);
+          if (DestPS.StoredType != PointerStatus::StoredType::Stored) {
+            ReplaceMemcpy(Dest, V, MC);
+            // V still need to be flatten.
+            // Lower memcpy come from Dest.
+            return LowerMemcpy(V, annotation, typeSys, DL, bAllowReplace);
+          }
+        }
       }
     }
   }

+ 29 - 16
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -3463,6 +3463,8 @@ static void SimplifyBitCast(BitCastOperator *BC, std::vector<Instruction *> &dea
         }
     } else if (CallInst *CI = dyn_cast<CallInst>(U)) {
       // Skip function call.
+    } else if (BitCastInst *Cast = dyn_cast<BitCastInst>(U)) {
+      // Skip bitcast.
     } else {
       DXASSERT(0, "not support yet");
     }
@@ -5395,23 +5397,34 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF,
     clang::QualType SrcTy,
     llvm::Value *DestPtr,
     clang::QualType DestTy) {
-    // It is possiable to implement EmitHLSLAggregateCopy, EmitHLSLAggregateStore the same way.
-    // But split value to scalar will generate many instruction when src type is same as dest type.
-    SmallVector<Value *, 4> idxList;
-    SmallVector<Value *, 4> SrcGEPList;
-    SmallVector<QualType, 4> SrcEltTyList;
-    FlattenAggregatePtrToGepList(CGF, SrcPtr, idxList, SrcTy, SrcPtr->getType(), SrcGEPList,
+  llvm::Type *SrcPtrTy = SrcPtr->getType()->getPointerElementType();
+  llvm::Type *DestPtrTy = DestPtr->getType()->getPointerElementType();
+  if (SrcPtrTy == DestPtrTy) {
+    // Memcpy if type is match.
+    unsigned size = TheModule.getDataLayout().getTypeAllocSize(SrcPtrTy);
+    CGF.Builder.CreateMemCpy(DestPtr, SrcPtr, size, 1);
+    return;
+  }
+  // It is possiable to implement EmitHLSLAggregateCopy, EmitHLSLAggregateStore
+  // the same way. But split value to scalar will generate many instruction when
+  // src type is same as dest type.
+  SmallVector<Value *, 4> idxList;
+  SmallVector<Value *, 4> SrcGEPList;
+  SmallVector<QualType, 4> SrcEltTyList;
+  FlattenAggregatePtrToGepList(CGF, SrcPtr, idxList, SrcTy, SrcPtr->getType(),
+                               SrcGEPList, SrcEltTyList);
+
+  SmallVector<Value *, 4> LdEltList;
+  LoadFlattenedGepList(CGF, SrcGEPList, SrcEltTyList, LdEltList);
+
+  idxList.clear();
+  SmallVector<Value *, 4> DestGEPList;
+  SmallVector<QualType, 4> DestEltTyList;
+  FlattenAggregatePtrToGepList(CGF, DestPtr, idxList, DestTy,
+                               DestPtr->getType(), DestGEPList, DestEltTyList);
+
+  StoreFlattenedGepList(CGF, DestGEPList, DestEltTyList, LdEltList,
                         SrcEltTyList);
-
-    SmallVector<Value *, 4> LdEltList;
-    LoadFlattenedGepList(CGF, SrcGEPList, SrcEltTyList, LdEltList);
-
-    idxList.clear();
-    SmallVector<Value *, 4> DestGEPList;
-    SmallVector<QualType, 4> DestEltTyList;
-    FlattenAggregatePtrToGepList(CGF, DestPtr, idxList, DestTy, DestPtr->getType(), DestGEPList, DestEltTyList);
-
-    StoreFlattenedGepList(CGF, DestGEPList, DestEltTyList, LdEltList, SrcEltTyList);
 }
 
 void CGMSHLSLRuntime::EmitHLSLAggregateStore(CodeGenFunction &CGF, llvm::Value *SrcVal,

+ 29 - 0
tools/clang/test/CodeGenHLSL/cbuffer_copy4.hlsl

@@ -0,0 +1,29 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Make sure no alloca to copy.
+// CHECK-NOT: alloca
+
+struct M {
+  float4x4  m;
+};
+
+cbuffer T
+{
+	M a[2];
+	float4 b[2];
+}
+struct ST
+{
+	M a[2];
+	float4 b[2];
+};
+
+uint i;
+
+float4 main() : SV_Target
+{
+  ST st;
+  st.a = a;
+  st.b = b;
+  return mul(st.a[i].m, st.b[i]);
+}

+ 1 - 0
tools/clang/test/CodeGenHLSL/class.hlsl

@@ -21,5 +21,6 @@ X x0;
 float4 main(float4 a : A, float4 b:B) : SV_TARGET
 {
   X x = x0;
+  x.n2[0].n = 1;
   return x.test_inout(a.x).n;
 }

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -394,6 +394,7 @@ public:
   TEST_METHOD(CodeGenCbufferCopy)
   TEST_METHOD(CodeGenCbufferCopy2)
   TEST_METHOD(CodeGenCbufferCopy3)
+  TEST_METHOD(CodeGenCbufferCopy4)
   TEST_METHOD(CodeGenCbuffer_unused)
   TEST_METHOD(CodeGenCbuffer1_50)
   TEST_METHOD(CodeGenCbuffer1_51)
@@ -2321,6 +2322,10 @@ TEST_F(CompilerTest, CodeGenCbufferCopy3) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\cbuffer_copy3.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenCbufferCopy4) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\cbuffer_copy4.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenCbuffer_unused) {
   CodeGenTest(L"..\\CodeGenHLSL\\cbuffer_unused.hlsl");
 }