Browse Source

Support case for memcpy float3[1], float3[1][1]. (#2280)

Xiang Li 6 years ago
parent
commit
365e6e8aa9

+ 42 - 20
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -3306,7 +3306,9 @@ static void CopyElementsOfStructsWithIdenticalLayout(
   }
 }
 
-static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC) {
+static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC,
+                          DxilFieldAnnotation *annotation,
+                          DxilTypeSystem &typeSys, const DataLayout &DL) {
   Type *TyV = V->getType()->getPointerElementType();
   Type *TySrc = Src->getType()->getPointerElementType();
   if (Constant *C = dyn_cast<Constant>(V)) {
@@ -3337,24 +3339,44 @@ static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC) {
 
       BitCastInst *DestBCI = cast<BitCastInst>(DestVal);
       BitCastInst *SrcBCI = cast<BitCastInst>(SrcVal);
-      if (!ArePointersToStructsOfIdenticalLayouts(DestBCI->getSrcTy(), SrcBCI->getSrcTy())) {
-        DXASSERT(0, "Can't handle structs of different layouts");
-        return;
-      }
 
-      const DataLayout &DL = SrcBCI->getModule()->getDataLayout();
-      unsigned SrcSize = DL.getTypeAllocSize(SrcBCI->getOperand(0)->getType()->getPointerElementType());
-      unsigned MemcpySize = cast<ConstantInt>(MC->getLength())->getZExtValue() * MC->getAlignment();
-      if (SrcSize != MemcpySize) {
-        DXASSERT(0, "Cannot handle partial memcpy");
-        return;
-      }
+      Type* DstTy = DestBCI->getSrcTy();
+      Type *SrcTy = SrcBCI->getSrcTy();
+      if (ArePointersToStructsOfIdenticalLayouts(DstTy, SrcTy)) {
+        const DataLayout &DL = SrcBCI->getModule()->getDataLayout();
+        unsigned SrcSize = DL.getTypeAllocSize(
+            SrcBCI->getOperand(0)->getType()->getPointerElementType());
+        unsigned MemcpySize =
+            cast<ConstantInt>(MC->getLength())->getZExtValue() *
+            MC->getAlignment();
+        if (SrcSize != MemcpySize) {
+          DXASSERT(0, "Cannot handle partial memcpy");
+          return;
+        }
 
-      if (DestBCI->hasOneUse() && SrcBCI->hasOneUse()) {
-        IRBuilder<> Builder(MC);
-        StructType *srcStTy = cast<StructType>(SrcBCI->getOperand(0)->getType()->getPointerElementType());
-        std::vector<unsigned> idxlist = { 0 };
-        CopyElementsOfStructsWithIdenticalLayout(Builder, DestBCI->getOperand(0), SrcBCI->getOperand(0), srcStTy, idxlist);
+        if (DestBCI->hasOneUse() && SrcBCI->hasOneUse()) {
+          IRBuilder<> Builder(MC);
+          StructType *srcStTy = cast<StructType>(
+              SrcBCI->getOperand(0)->getType()->getPointerElementType());
+          std::vector<unsigned> idxlist = {0};
+          CopyElementsOfStructsWithIdenticalLayout(
+              Builder, DestBCI->getOperand(0), SrcBCI->getOperand(0), srcStTy,
+              idxlist);
+        }
+      } else {
+        if (DstTy == SrcTy) {
+          Value *DstPtr = DestBCI->getOperand(0);
+          Value *SrcPtr = SrcBCI->getOperand(0);
+          if (isa<GEPOperator>(DstPtr) || isa<GEPOperator>(SrcPtr)) {
+            MemcpySplitter::SplitMemCpy(MC, DL, annotation, typeSys);
+            return;
+          } else {
+            DstPtr->replaceAllUsesWith(SrcPtr);
+          }
+        } else {
+          DXASSERT(0, "Can't handle structs of different layouts");
+          return;
+        }
       }
     } else {
       DXASSERT(IsUnboundedArrayMemcpy(TyV, TySrc), "otherwise mismatched types in memcpy are not unbounded array");
@@ -3527,7 +3549,7 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
                   static_cast<HLSubscriptOpcode>(hlsl::GetHLOpcode(PtrCI));
               if (opcode == HLSubscriptOpcode::CBufferSubscript) {
                 // Ptr from CBuffer is safe.
-                ReplaceMemcpy(V, Src, MC);
+                ReplaceMemcpy(V, Src, MC, annotation, typeSys, DL);
                 return true;
               }
             }
@@ -3539,7 +3561,7 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
           PointerStatus SrcPS(size);
           PointerStatus::analyzePointer(Src, SrcPS, typeSys, bStructElt);
           if (SrcPS.storedType != PointerStatus::StoredType::Stored) {
-            ReplaceMemcpy(V, Src, MC);
+            ReplaceMemcpy(V, Src, MC, annotation, typeSys, DL);
             return true;
           }
         }
@@ -3562,7 +3584,7 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
           PointerStatus DestPS(size);
           PointerStatus::analyzePointer(Dest, DestPS, typeSys, bStructElt);
           if (DestPS.storedType != PointerStatus::StoredType::Stored) {
-            ReplaceMemcpy(Dest, V, MC);
+            ReplaceMemcpy(Dest, V, MC, annotation, typeSys, DL);
             // V still needs to be flattened.
             // Lower the memcpy coming from Dest.
             return LowerMemcpy(V, annotation, typeSys, DL, bAllowReplace);

+ 32 - 0
tools/clang/test/CodeGenHLSL/batch/expressions/conversions_and_casts/array_copy.hlsl

@@ -0,0 +1,32 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Make sure memcpy float3[1], float3[1][1] works.
+
+// CHECK:call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32
+// CHECK:fadd
+
+struct A
+{
+  float3 a;
+  float3 b[1];
+};
+
+float doubleDim(A a[1]) {
+  return a[0].b[0].x + a[0].b[0].y;
+}
+
+float3 c;
+float3 d[1];
+
+A getA() {
+  A a;
+  a.a = c;
+  a.b = d;
+  return a;
+}
+
+float main() : SV_Target {
+  A a[1];
+  a[0] = getA();
+  return doubleDim(a);
+}