ソースを参照

Avoid cbuffer copy for special pattern. (#3483)

* Avoid the cbuffer copy for a special pattern: a cbuffer array of float4 that is copied into a scalar array, e.g.

float4 cb[16];
float v[64] = cb;
Xiang Li 4 年 前
コミット
6f9a803bbe

+ 9 - 3
lib/HLSL/HLOperationLower.cpp

@@ -6638,15 +6638,21 @@ void TranslateCBGepLegacy(GetElementPtrInst *GEP, Value *handle,
         }
       } else {
         Type *EltTy = GEPIt->getVectorElementType();
+        unsigned size = DL.getTypeSizeInBits(EltTy);
+        unsigned vecSize = 4;
+        if (size == 64)
+          vecSize = 2;
+        else if (size == 16)
+          vecSize = 8;
         // Load the whole register.
         Value *newLd = GenerateCBLoadLegacy(handle, legacyIndex,
                                      /*channelOffset*/ 0, EltTy,
-                                     /*vecSize*/ 4, hlslOP, Builder);
+                                     /*vecSize*/ vecSize, hlslOP, Builder);
         // Copy to array.
         IRBuilder<> AllocaBuilder(GEP->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
-        Value *tempArray = AllocaBuilder.CreateAlloca(ArrayType::get(EltTy, 4));
+        Value *tempArray = AllocaBuilder.CreateAlloca(ArrayType::get(EltTy, vecSize));
         Value *zeroIdx = hlslOP->GetU32Const(0);
-        for (unsigned i = 0; i < 4; i++) {
+        for (unsigned i = 0; i < vecSize; i++) {
           Value *Elt = Builder.CreateExtractElement(newLd, i);
           Value *EltGEP = Builder.CreateInBoundsGEP(tempArray, {zeroIdx, hlslOP->GetU32Const(i)});
           Builder.CreateStore(Elt, EltGEP);

+ 185 - 3
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -1007,6 +1007,79 @@ DxilFieldAnnotation *FindAnnotationFromMatUser(Value *Mat,
   return nullptr;
 }
 
+namespace {
+bool isCBVec4ArrayToScalarArray(Type *TyV, Value *Src, Type *TySrc, const DataLayout &DL) { // Returns true when Src comes from a cbuffer subscript, TySrc is an array of 128-bit vectors, and TyV is an array of the same scalar element type (>= 32 bits).
+  Value *SrcPtr = Src;
+  while (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(SrcPtr)) { // Strip any chain of GEP instructions to reach the base pointer.
+    SrcPtr = GEP->getPointerOperand();
+  }
+  CallInst *CI = dyn_cast<CallInst>(SrcPtr);
+  if (!CI) // The base pointer must be the result of a call.
+    return false;

+  Function *F = CI->getCalledFunction(); // NOTE(review): may be null for an indirect call; confirm GetHLOpcodeGroupByName tolerates that.
+  if (hlsl::GetHLOpcodeGroupByName(F) != HLOpcodeGroup::HLSubscript) // Only calls in the HLSubscript group qualify.
+    return false;

+  if (hlsl::GetHLOpcode(CI) != (unsigned)HLSubscriptOpcode::CBufferSubscript) // ...and only the CBufferSubscript opcode.
+    return false;

+  ArrayType *AT = dyn_cast<ArrayType>(TySrc);
+  if (!AT) // Source type must be an array...
+    return false;
+  VectorType *VT = dyn_cast<VectorType>(AT->getElementType());

+  if (!VT) // ...whose elements are vectors...
+    return false;

+  if (DL.getTypeSizeInBits(VT) != 128) // ...that are exactly 128 bits wide.
+    return false;

+  ArrayType *DstAT = dyn_cast<ArrayType>(TyV);
+  if (!DstAT) // Destination type must be an array too.
+    return false;

+  if (VT->getElementType() != DstAT->getElementType()) // Destination elements must be the source vector's scalar type.
+    return false;

+  unsigned sizeInBits = DL.getTypeSizeInBits(VT->getElementType());
+  if (sizeInBits < 32) // Scalars narrower than 32 bits are rejected.
+    return false;
+  return true;
+}
+
+bool trySplitCBVec4ArrayToScalarArray(Value *Dest, Type *TyV, Value *Src,
+                                      Type *TySrc, const DataLayout &DL,
+                                      IRBuilder<> &B) { // Emits, through B, per-element load/extract/store code copying Src into Dest; returns false without emitting anything when the pattern does not match.
+  if (!isCBVec4ArrayToScalarArray(TyV, Src, TySrc, DL))
+    return false;

+  ArrayType *AT = cast<ArrayType>(TyV); // Safe: the check above already established that TyV is an array.
+  Type *EltTy = AT->getElementType();
+  unsigned sizeInBits = DL.getTypeSizeInBits(EltTy);
+  unsigned vecSize = 4; // Elements per source vector: 4 by default...
+  if (sizeInBits == 64)
+    vecSize = 2; // ...2 when the scalar is 64 bits wide.
+  unsigned arraySize = AT->getNumElements();
+  unsigned vecArraySize = arraySize / vecSize; // NOTE(review): integer division; trailing elements are not copied if arraySize is not a multiple of vecSize. Confirm callers guarantee this.
+  Value *zeroIdx = B.getInt32(0);
+  for (unsigned a = 0; a < vecArraySize; a++) {
+    Value *SrcGEP = B.CreateGEP(Src, {zeroIdx, B.getInt32(a)}); // Address of the a-th source vector.
+    Value *Ld = B.CreateLoad(SrcGEP); // Load the whole vector once, then scatter its components.
+    for (unsigned v = 0; v < vecSize; v++) {
+      Value *Elt = B.CreateExtractElement(Ld, v);

+      Value *DestGEP =
+          B.CreateGEP(Dest, {zeroIdx, B.getInt32(a * vecSize + v)}); // Flat destination index of component v of vector a.
+      B.CreateStore(Elt, DestGEP);
+    }
+  }

+  return true;
+}
+
+}
+
 void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
                                  DxilFieldAnnotation *fieldAnnotation,
                                  DxilTypeSystem &typeSys,
@@ -1031,6 +1104,11 @@ void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
 
   // Allow copy between different address space.
   if (DestTy != SrcTy) {
+    if (trySplitCBVec4ArrayToScalarArray(Dest, DestTy, Src, SrcTy, DL,
+                                         Builder)) {
+      // delete memcpy
+      DeleteMemcpy(MI);
+    }
     return;
   }
   // Try to find fieldAnnotation from user of Dest/Src.
@@ -3254,6 +3332,106 @@ static void updateLifetimeForReplacement(Value *From, Value *To)
 
 static bool DominateAllUsers(Instruction *I, Value *V, DominatorTree *DT);
 
+namespace {
+void replaceScalarArrayGEPWithVectorArrayGEP(User *GEP, Value *VectorArray,
+                                             IRBuilder<> &Builder,
+                                             unsigned sizeInDwords) { // Rewrites a two-index GEP into a scalar array as a GEP into VectorArray (vector index, then component index) and redirects the old GEP's uses to it.
+  gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);

+  Value *PtrOffset = GEPIt.getOperand(); // First GEP index.
+  ++GEPIt;
+  Value *ArrayIdx = GEPIt.getOperand(); // Second GEP index.
+  ++GEPIt;
+  ArrayIdx = Builder.CreateAdd(PtrOffset, ArrayIdx); // Fold both indices into one flat scalar index.
+  DXASSERT_LOCALVAR(E, GEPIt == E, "invalid gep on scalar array"); // The GEP must have exactly two indices.

+  unsigned shift = 2; // Defaults are the same values as the 1-dword case below.
+  unsigned mask = 0x3;
+  switch (sizeInDwords) {
+  case 2: // 2-dword scalars: flat index >> 1 selects the vector, & 1 the component.
+    shift = 1;
+    mask = 1;
+    break;
+  case 1: // 1-dword scalars: flat index >> 2 selects the vector, & 3 the component.
+    shift = 2;
+    mask = 0x3;
+    break;
+  default:
+    DXASSERT(0, "invalid scalar size");
+    break;
+  }

+  Value *VecIdx = Builder.CreateLShr(ArrayIdx, shift); // Which vector of VectorArray.
+  Value *VecPtr = Builder.CreateGEP(
+      VectorArray, {ConstantInt::get(VecIdx->getType(), 0), VecIdx});
+  Value *CompIdx = Builder.CreateAnd(ArrayIdx, mask); // Which component inside that vector.
+  Value *NewGEP = Builder.CreateGEP(
+      VecPtr, {ConstantInt::get(CompIdx->getType(), 0), CompIdx});
+  GEP->replaceAllUsesWith(NewGEP); // The old GEP is not erased here; that is left to the caller.
+}
+
+void replaceScalarArrayWithVectorArray(Value *ScalarArray, Value *VectorArray,
+                                       MemCpyInst *MC, unsigned sizeInDwords) { // Redirects the users of ScalarArray (bitcasts, GEPs, addrspacecast constant expressions) to VectorArray.
+  LLVMContext &Context = ScalarArray->getContext();
+  // All users should be of element type.
+  // Replace users of AI or GV.
+  for (auto it = ScalarArray->user_begin(); it != ScalarArray->user_end();) {
+    User *U = *(it++); // Advance before touching U: the body may rewrite or erase it.
+    if (U->user_empty()) // A user that itself has no users is left untouched.
+      continue;
+    if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
+      BCI->setOperand(0, VectorArray); // Bitcasts simply get their operand swapped.
+      continue;
+    }

+    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U)) {
+      IRBuilder<> Builder(Context); // Constructed from the context only (no insertion point is given).
+      if (GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
+        // NewGEP must be a GEPOperator too.
+        // No instruction will be built.
+        replaceScalarArrayGEPWithVectorArrayGEP(U, VectorArray, Builder,
+                                                sizeInDwords);
+      } else if (CE->getOpcode() == Instruction::AddrSpaceCast) {
+        Value *NewAddrSpaceCast = Builder.CreateAddrSpaceCast(
+            VectorArray,
+            PointerType::get(VectorArray->getType()->getPointerElementType(),
+                             CE->getType()->getPointerAddressSpace()));
+        replaceScalarArrayWithVectorArray(CE, NewAddrSpaceCast, MC,
+                                          sizeInDwords); // Recurse so the users of the cast are redirected as well.
+      } else if (CE->hasOneUse() && CE->user_back() == MC) { // A constant expression used only by the memcpy MC is skipped.
+        continue;
+      } else {
+        DXASSERT(0, "not implemented");
+      }
+    } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
+      IRBuilder<> Builder(GEP); // Insert the replacement right before the old GEP.
+      replaceScalarArrayGEPWithVectorArrayGEP(U, VectorArray, Builder,
+                                              sizeInDwords);
+      GEP->eraseFromParent(); // Its uses were redirected by the call above.
+    } else {
+      DXASSERT(0, "not implemented");
+    }
+  }
+}
+
+// Handles a pattern like
+// float4 cb[16];
+// float v[64] = cb;
+bool tryToReplaceCBVec4ArrayToScalarArray(Value *V, Type *TyV, Value *Src,
+                                          Type *TySrc, MemCpyInst *MC,
+                                          const DataLayout &DL) { // Returns false, changing nothing, when the pattern does not match.
+  if (!isCBVec4ArrayToScalarArray(TyV, Src, TySrc, DL))
+    return false;

+  ArrayType *AT = cast<ArrayType>(TyV); // Safe: the check above already established that TyV is an array.
+  Type *EltTy = AT->getElementType();
+  unsigned sizeInBits = DL.getTypeSizeInBits(EltTy);
+  // Redirect users of the scalar array V to the vector array Src.
+  replaceScalarArrayWithVectorArray(V, Src, MC, sizeInBits >> 5); // >> 5 converts the element size from bits to 32-bit dwords.
+  return true;
+}
+
+} // namespace
 
 static bool ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC,
                           DxilFieldAnnotation *annotation, DxilTypeSystem &typeSys,
@@ -3285,9 +3463,13 @@ static bool ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC,
         ReplaceConstantWithInst(C, Src, Builder);
       }
     } else {
-      IRBuilder<> Builder(MC);
-      Src = Builder.CreateBitCast(Src, V->getType());
-      ReplaceConstantWithInst(C, Src, Builder);
+      // Try convert special pattern for cbuffer which copy array of float4 to
+      // array of float.
+      if (!tryToReplaceCBVec4ArrayToScalarArray(V, TyV, Src, TySrc, MC, DL)) {
+        IRBuilder<> Builder(MC);
+        Src = Builder.CreateBitCast(Src, V->getType());
+        ReplaceConstantWithInst(C, Src, Builder);
+      }
     }
   } else {
     if (TyV == TySrc) {

+ 38 - 7
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -5306,13 +5306,33 @@ static bool IsTypeMatchForMemcpy(llvm::Type *SrcTy, llvm::Type *DestTy) {
   }
 }
 
+static bool IsVec4ArrayToScalarArrayForMemcpy(llvm::Type *SrcTy, llvm::Type *DestTy, const DataLayout &DL) { // True when the array element types are equal, or SrcTy's element is a 128-bit vector whose scalar type (>= 32 bits) equals DestTy's element type.
+  if (!SrcTy->isArrayTy()) // The source must be an array.
+    return false;
+  llvm::Type *SrcEltTy = dxilutil::GetArrayEltTy(SrcTy);
+  llvm::Type *DestEltTy = dxilutil::GetArrayEltTy(DestTy);
+  if (SrcEltTy == DestEltTy) // Identical element types are accepted immediately.
+    return true;
+  llvm::VectorType *VT  = dyn_cast<llvm::VectorType>(SrcEltTy);
+  if (!VT) // Otherwise the source element must be a vector...
+    return false;

+  if (DL.getTypeSizeInBits(VT) != 128) // ...that is exactly 128 bits wide...
+    return false;

+  if (DL.getTypeSizeInBits(DestEltTy) < 32) // ...and the destination scalar must be at least 32 bits.
+    return false;

+  return VT->getElementType() == DestEltTy; // Scalar type of the vector must be the destination element type.
+}
+
 void CGMSHLSLRuntime::EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF, llvm::Value *SrcPtr,
     clang::QualType SrcTy,
     llvm::Value *DestPtr,
     clang::QualType DestTy) {
   llvm::Type *SrcPtrTy = SrcPtr->getType()->getPointerElementType();
   llvm::Type *DestPtrTy = DestPtr->getType()->getPointerElementType();
-
+  const DataLayout &DL = TheModule.getDataLayout();
   bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
   if (SrcPtrTy == DestPtrTy) {
     bool bMatArrayRotate = false;
@@ -5326,7 +5346,7 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF,
     }
     if (!bMatArrayRotate) {
       // Memcpy if type is match.
-      unsigned size = TheModule.getDataLayout().getTypeAllocSize(SrcPtrTy);
+      unsigned size = DL.getTypeAllocSize(SrcPtrTy);
       CGF.Builder.CreateMemCpy(DestPtr, SrcPtr, size, 1);
       return;
     }
@@ -5353,24 +5373,35 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF,
       Value *Cast = CGF.Builder.CreateBitCast(
           SrcPtr,
           ResultTy->getPointerTo(DestPtr->getType()->getPointerAddressSpace()));
-      unsigned size = TheModule.getDataLayout().getTypeAllocSize(
+      unsigned size = DL.getTypeAllocSize(
           DestPtrTy);
       CGF.Builder.CreateMemCpy(DestPtr, Cast, size, 1);
       return;
     }
   } else if (dxilutil::IsHLSLObjectType(dxilutil::GetArrayEltTy(SrcPtrTy)) &&
              dxilutil::IsHLSLObjectType(dxilutil::GetArrayEltTy(DestPtrTy))) {
-    unsigned sizeSrc = TheModule.getDataLayout().getTypeAllocSize(SrcPtrTy);
-    unsigned sizeDest = TheModule.getDataLayout().getTypeAllocSize(DestPtrTy);
+    unsigned sizeSrc = DL.getTypeAllocSize(SrcPtrTy);
+    unsigned sizeDest = DL.getTypeAllocSize(DestPtrTy);
     CGF.Builder.CreateMemCpy(DestPtr, SrcPtr, std::max(sizeSrc, sizeDest), 1);
     return;
   } else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(DestPtr)) {
     if (GV->isInternalLinkage(GV->getLinkage()) &&
         IsTypeMatchForMemcpy(SrcPtrTy, DestPtrTy)) {
-      unsigned sizeSrc = TheModule.getDataLayout().getTypeAllocSize(SrcPtrTy);
-      unsigned sizeDest = TheModule.getDataLayout().getTypeAllocSize(DestPtrTy);
+      unsigned sizeSrc = DL.getTypeAllocSize(SrcPtrTy);
+      unsigned sizeDest = DL.getTypeAllocSize(DestPtrTy);
       CGF.Builder.CreateMemCpy(DestPtr, SrcPtr, std::min(sizeSrc, sizeDest), 1);
       return;
+    } else if (GlobalVariable *SrcGV = dyn_cast<GlobalVariable>(SrcPtr)) {
+      if (GV->isInternalLinkage(GV->getLinkage()) &&
+          m_ConstVarAnnotationMap.count(SrcGV) &&
+          IsVec4ArrayToScalarArrayForMemcpy(SrcPtrTy, DestPtrTy, DL)) {
+        unsigned sizeSrc = DL.getTypeAllocSize(SrcPtrTy);
+        unsigned sizeDest = DL.getTypeAllocSize(DestPtrTy);
+        if (sizeSrc == sizeDest) {
+          CGF.Builder.CreateMemCpy(DestPtr, SrcPtr, sizeSrc, 1);
+          return;
+        }
+      }
     }
   }
 

+ 19 - 0
tools/clang/test/HLSLFileCheck/hlsl/objects/CbufferLegacy/cb_array_mutate.hlsl

@@ -0,0 +1,19 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s

+// CHECK:%[[ID:[0-9]+]] = call i32 @dx.op.loadInput
+// CHECK:lshr i32 %[[ID]], 2
+// CHECK:and i32 %[[ID]], 3
+// Make sure there is only one cbuffer load.
+// CHECK:call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32
+// CHECK-NOT:call %dx.types.CBufRet

+cbuffer Pack
+{
+    int4 __packed[16]; // 16 int4 vectors = 64 ints.
+};

+static int arrayReallyWant[64] = (int[64])__packed; // Scalar view of the vector array, initialized from the cbuffer.

+float main(int i:I) : SV_Target {
+  return arrayReallyWant[i]; // Dynamic index taken from the shader input; the array is only read.
+}

+ 19 - 0
tools/clang/test/HLSLFileCheck/hlsl/objects/CbufferLegacy/cb_array_mutate2.hlsl

@@ -0,0 +1,19 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s

+// CHECK:%[[ID:[0-9]+]] = call i32 @dx.op.loadInput
+// CHECK:lshr i32 %[[ID]], 1
+// CHECK:and i32 %[[ID]], 1
+// Make sure there is only one cbuffer load.
+// CHECK:call %dx.types.CBufRet.i64 @dx.op.cbufferLoadLegacy.i64
+// CHECK-NOT:call %dx.types.CBufRet

+cbuffer Pack
+{
+    int64_t2 __packed[16]; // 16 int64_t2 vectors = 32 int64_t values.
+};

+static int64_t arrayReallyWant[32] = (int64_t[32])__packed; // Scalar view of the vector array, initialized from the cbuffer.

+float main(int i:I) : SV_Target {
+  return arrayReallyWant[i]; // Dynamic index taken from the shader input; the array is only read.
+}

+ 20 - 0
tools/clang/test/HLSLFileCheck/hlsl/objects/CbufferLegacy/cb_array_mutate3.hlsl

@@ -0,0 +1,20 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s

+// Make sure there are 4 cbuffer loads.
+// CHECK:call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32
+// CHECK:call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32
+// CHECK:call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32
+// CHECK:call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32
+// CHECK-NOT:call %dx.types.CBufRet

+cbuffer Pack
+{
+    int4 __packed[4]; // 4 int4 vectors = 16 ints.
+};

+static int arrayReallyWant[16] = (int[16])__packed; // Scalar view of the vector array, initialized from the cbuffer.

+float main(int i:I) : SV_Target {
+  arrayReallyWant[0] = 3; // The array is written here; presumably this is why a real copy (one load per int4) is expected instead of a single load. Confirm.
+  return arrayReallyWant[i];
+}