Browse Source

Fixed a crash with passing struct array element to functions (#2245)

Adam Yang 6 years ago
parent
commit
bbf44f7a22

+ 106 - 65
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -3247,6 +3247,65 @@ static void ReplaceUnboundedArrayUses(Value *V, Value *Src, IRBuilder<> &Builder
   }
 }
 
+static bool IsUnboundedArrayMemcpy(Type *destTy, Type *srcTy) {
+  return (destTy->isArrayTy() && srcTy->isArrayTy()) &&
+    (destTy->getArrayNumElements() == 0 || srcTy->getArrayNumElements() == 0);
+}
+
+static bool ArePointersToStructsOfIdenticalLayouts(Type *DstTy, Type *SrcTy) {
+  if (!SrcTy->isPointerTy() || !DstTy->isPointerTy())
+    return false;
+  DstTy = DstTy->getPointerElementType();
+  SrcTy = SrcTy->getPointerElementType();
+  if (!SrcTy->isStructTy() || !DstTy->isStructTy())
+    return false;
+  StructType *DstST = cast<StructType>(DstTy);
+  StructType *SrcST = cast<StructType>(SrcTy);
+  return SrcST->isLayoutIdentical(DstST);
+}
+
+static std::vector<Value *> GetConstValueIdxList(IRBuilder<> &builder,
+  std::vector<unsigned> idxlist) {
+  std::vector<Value *> idxConstList;
+  for (unsigned idx : idxlist) {
+    idxConstList.push_back(ConstantInt::get(builder.getInt32Ty(), idx));
+  }
+  return idxConstList;
+}
+
+static void CopyElementsOfStructsWithIdenticalLayout(
+  IRBuilder<> &builder, Value *destPtr, Value *srcPtr, Type *ty,
+  std::vector<unsigned>& idxlist) {
+  if (ty->isStructTy()) {
+    for (unsigned i = 0; i < ty->getStructNumElements(); i++) {
+      idxlist.push_back(i);
+      CopyElementsOfStructsWithIdenticalLayout(
+        builder, destPtr, srcPtr, ty->getStructElementType(i), idxlist);
+      idxlist.pop_back();
+    }
+  }
+  else if (ty->isArrayTy()) {
+    for (unsigned i = 0; i < ty->getArrayNumElements(); i++) {
+      idxlist.push_back(i);
+      CopyElementsOfStructsWithIdenticalLayout(
+        builder, destPtr, srcPtr, ty->getArrayElementType(), idxlist);
+      idxlist.pop_back();
+    }
+  }
+  else if (ty->isIntegerTy() || ty->isFloatTy() || ty->isDoubleTy() ||
+    ty->isHalfTy() || ty->isVectorTy()) {
+    Value *srcGEP =
+      builder.CreateInBoundsGEP(srcPtr, GetConstValueIdxList(builder, idxlist));
+    Value *destGEP =
+      builder.CreateInBoundsGEP(destPtr, GetConstValueIdxList(builder, idxlist));
+    LoadInst *LI = builder.CreateLoad(srcGEP);
+    builder.CreateStore(LI, destGEP);
+  }
+  else {
+    DXASSERT(0, "encountered unsupported type when copying elements of identical structs.");
+  }
+}
+
 static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC) {
   Type *TyV = V->getType()->getPointerElementType();
   Type *TySrc = Src->getType()->getPointerElementType();
@@ -3268,15 +3327,42 @@ static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC) {
     if (TyV == TySrc) {
       if (V != Src)
         V->replaceAllUsesWith(Src);
+    } else if (!IsUnboundedArrayMemcpy(TyV, TySrc)) {
+      Value* DestVal = MC->getRawDest();
+      Value* SrcVal = MC->getRawSource();
+      if (!isa<BitCastInst>(SrcVal) || !isa<BitCastInst>(DestVal)) {
+        DXASSERT(0, "Encountered unexpected instruction sequence");
+        return;
+      }
+
+      BitCastInst *DestBCI = cast<BitCastInst>(DestVal);
+      BitCastInst *SrcBCI = cast<BitCastInst>(SrcVal);
+      if (!ArePointersToStructsOfIdenticalLayouts(DestBCI->getSrcTy(), SrcBCI->getSrcTy())) {
+        DXASSERT(0, "Can't handle structs of different layouts");
+        return;
+      }
+
+      const DataLayout &DL = SrcBCI->getModule()->getDataLayout();
+      unsigned SrcSize = DL.getTypeAllocSize(SrcBCI->getOperand(0)->getType()->getPointerElementType());
+      unsigned MemcpySize = cast<ConstantInt>(MC->getLength())->getZExtValue() * MC->getAlignment();
+      if (SrcSize != MemcpySize) {
+        DXASSERT(0, "Cannot handle partial memcpy");
+        return;
+      }
+
+      if (DestBCI->hasOneUse() && SrcBCI->hasOneUse()) {
+        IRBuilder<> Builder(MC);
+        StructType *srcStTy = cast<StructType>(SrcBCI->getOperand(0)->getType()->getPointerElementType());
+        std::vector<unsigned> idxlist = { 0 };
+        CopyElementsOfStructsWithIdenticalLayout(Builder, DestBCI->getOperand(0), SrcBCI->getOperand(0), srcStTy, idxlist);
+      }
     } else {
-      DXASSERT((TyV->isArrayTy() && TySrc->isArrayTy()) &&
-               (TyV->getArrayNumElements() == 0 ||
-                TySrc->getArrayNumElements() == 0),
-               "otherwise mismatched types in memcpy are not unbounded array");
+      DXASSERT(IsUnboundedArrayMemcpy(TyV, TySrc), "otherwise mismatched types in memcpy are not unbounded array");
       IRBuilder<> Builder(MC);
       ReplaceUnboundedArrayUses(V, Src, Builder);
     }
   }
+
   Value *RawDest = MC->getOperand(0);
   Value *RawSrc = MC->getOperand(1);
   MC->eraseFromParent();
@@ -3556,8 +3642,8 @@ public:
   static char ID; // Pass identification, replacement for typeid
   explicit SROA_Parameter_HLSL() : ModulePass(ID) {}
   const char *getPassName() const override { return "SROA Parameter HLSL"; }
-  static void CopyElementsOfStructsWithIdenticalLayout(IRBuilder<>& builder, Value* destPtr, Value* srcPtr, Type *ty, std::vector<unsigned>& idxlist);
   static void RewriteBitcastWithIdenticalStructs(Function *F);
+  static void RewriteBitcastWithIdenticalStructs(BitCastInst *BCI);
 
   bool runOnModule(Module &M) override {
     // Patch memcpy to cover case bitcast (gep ptr, 0,0) is transformed into
@@ -3751,7 +3837,7 @@ private:
     unsigned startArgIndex, llvm::StringMap<Type *> &semanticTypeMap);
   bool hasDynamicVectorIndexing(Value *V);
   void flattenGlobal(GlobalVariable *GV);
-  static std::vector<Value*> GetConstValueIdxList(IRBuilder<>& builder, std::vector<unsigned> idxlist);
+  //static std::vector<Value*> GetConstValueIdxList(IRBuilder<>& builder, std::vector<unsigned> idxlist);
   /// DeadInsts - Keep track of instructions we have made dead, so that
   /// we can remove them after we are done working.
   SmallVector<Value *, 32> DeadInsts;
@@ -3801,17 +3887,8 @@ void SROA_Parameter_HLSL::RewriteBitcastWithIdenticalStructs(Function *F) {
     if (BitCastInst *BCI = dyn_cast<BitCastInst>(&*I)) {
       Type *DstTy = BCI->getDestTy();
       Type *SrcTy = BCI->getSrcTy();
-      if (!SrcTy->isPointerTy() || !DstTy->isPointerTy())
-        continue;
-      DstTy = DstTy->getPointerElementType();
-      SrcTy = SrcTy->getPointerElementType();
-      if (!SrcTy->isStructTy() || !DstTy->isStructTy())
-        continue;
-      StructType *DstST = cast<StructType>(DstTy);
-      StructType *SrcST = cast<StructType>(SrcTy);
-      if (!SrcST->isLayoutIdentical(DstST))
-        continue;
-      worklist.push_back(BCI);
+      if(ArePointersToStructsOfIdenticalLayouts(DstTy, SrcTy))
+        worklist.push_back(BCI);
     }
   }
 
@@ -3819,57 +3896,21 @@ void SROA_Parameter_HLSL::RewriteBitcastWithIdenticalStructs(Function *F) {
   while (!worklist.empty()) {
     BitCastInst *BCI = worklist.back();
     worklist.pop_back();
-    StructType *srcStTy = cast<StructType>(BCI->getSrcTy()->getPointerElementType());
-    StructType *destStTy = cast<StructType>(BCI->getDestTy()->getPointerElementType());
-    Value* srcPtr = BCI->getOperand(0);
-    IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(BCI->getParent()->getParent()));
-    AllocaInst *destPtr = AllocaBuilder.CreateAlloca(destStTy);
-    IRBuilder<> InstBuilder(BCI);
-    std::vector<unsigned> idxlist = { 0 };
-    CopyElementsOfStructsWithIdenticalLayout(InstBuilder, destPtr, srcPtr, srcStTy, idxlist);
-    BCI->replaceAllUsesWith(destPtr);
-    BCI->eraseFromParent();
-  }
-}
-
-std::vector<Value *>
-SROA_Parameter_HLSL::GetConstValueIdxList(IRBuilder<> &builder,
-                                          std::vector<unsigned> idxlist) {
-  std::vector<Value *> idxConstList;
-  for (unsigned idx : idxlist) {
-    idxConstList.push_back(ConstantInt::get(builder.getInt32Ty(), idx));
+    RewriteBitcastWithIdenticalStructs(BCI);
   }
-  return idxConstList;
 }
 
-void SROA_Parameter_HLSL::CopyElementsOfStructsWithIdenticalLayout(
-    IRBuilder<> &builder, Value *destPtr, Value *srcPtr, Type *ty,
-    std::vector<unsigned>& idxlist) {
-  if (ty->isStructTy()) {
-    for (unsigned i = 0; i < ty->getStructNumElements(); i++) {
-      idxlist.push_back(i);
-      CopyElementsOfStructsWithIdenticalLayout(
-          builder, destPtr, srcPtr, ty->getStructElementType(i), idxlist);
-      idxlist.pop_back();
-    }
-  } else if (ty->isArrayTy()) {
-    for (unsigned i = 0; i < ty->getArrayNumElements(); i++) {
-      idxlist.push_back(i);
-      CopyElementsOfStructsWithIdenticalLayout(
-          builder, destPtr, srcPtr, ty->getArrayElementType(), idxlist);
-      idxlist.pop_back();
-    }
-  } else if (ty->isIntegerTy() || ty->isFloatTy() || ty->isDoubleTy() ||
-             ty->isHalfTy() || ty->isVectorTy()) {
-    Value *srcGEP =
-        builder.CreateInBoundsGEP(srcPtr, GetConstValueIdxList(builder, idxlist));
-    Value *destGEP =
-        builder.CreateInBoundsGEP(destPtr, GetConstValueIdxList(builder, idxlist));
-    LoadInst *LI = builder.CreateLoad(srcGEP);
-    builder.CreateStore(LI, destGEP);
-  } else {
-    DXASSERT(0, "encountered unsupported type when copying elements of identical structs.");
-  }
+void SROA_Parameter_HLSL::RewriteBitcastWithIdenticalStructs(BitCastInst *BCI) {
+  StructType *srcStTy = cast<StructType>(BCI->getSrcTy()->getPointerElementType());
+  StructType *destStTy = cast<StructType>(BCI->getDestTy()->getPointerElementType());
+  Value* srcPtr = BCI->getOperand(0);
+  IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(BCI->getParent()->getParent()));
+  AllocaInst *destPtr = AllocaBuilder.CreateAlloca(destStTy);
+  IRBuilder<> InstBuilder(BCI);
+  std::vector<unsigned> idxlist = { 0 };
+  CopyElementsOfStructsWithIdenticalLayout(InstBuilder, destPtr, srcPtr, srcStTy, idxlist);
+  BCI->replaceAllUsesWith(destPtr);
+  BCI->eraseFromParent();
 }
 
 /// DeleteDeadInstructions - Erase instructions on the DeadInstrs list,

+ 22 - 0
tools/clang/test/CodeGenHLSL/batch/passes/sroa_hlsl/memcpy_types.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// Regression test a bug where passing array struct element to function causes
+// a crash due to memcpy src and dest type mismatch.
+
+// CHECK: @main
+ 
+struct my_struct
+{
+  float x : POSITION0;
+  float y : TEXCOORD0;
+};
+ 
+uint foo(my_struct s)
+{
+  return 0;
+}
+ 
+void main(my_struct a[1])
+{
+  uint r = foo(a[0]);
+}

+ 31 - 0
tools/clang/test/CodeGenHLSL/batch/passes/sroa_hlsl/memcpy_types_2.hlsl

@@ -0,0 +1,31 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// Regression test a bug where passing array struct element to function causes
+// a crash due to memcpy src and dest type mismatch.
+// This test makes sure that casting to a different type with the same layout still works.
+
+// CHECK: @main
+
+struct my_struct
+{
+  float x : POSITION0;
+  float y : TEXCOORD0;
+};
+
+struct my_struct_2
+{
+  float x;
+  float y;
+};
+ 
+uint foo(my_struct_2 s)
+{
+  return 0;
+}
+ 
+void main(my_struct a[1])
+{
+  uint r = foo((my_struct_2)a[0]);
+}
+
+

+ 30 - 0
tools/clang/test/CodeGenHLSL/batch/passes/sroa_hlsl/memcpy_types_mismatch.hlsl

@@ -0,0 +1,30 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// Test that casting here doesn't work.
+
+// CHECK-NOT: @main
+
+struct my_struct
+{
+  float x : POSITION0;
+  float y : TEXCOORD0;
+};
+
+struct my_struct_2
+{
+  float x;
+  float y;
+  float z;
+};
+ 
+uint foo(my_struct_2 s)
+{
+  return 0;
+}
+ 
+void main(my_struct a[1])
+{
+  uint r = foo(a[0]);
+}
+
+