Parcourir la source

Use memcpy when cast cbuffer constant into static global. (#1125)

Xiang Li il y a 7 ans
Parent
commit
277bb0bce3

+ 3 - 2
include/dxc/HLSL/DxilConstants.h

@@ -56,8 +56,9 @@ namespace DXIL {
   const unsigned kMaxStructBufferStride = 2048;
   const unsigned kMaxHSOutputControlPointsTotalScalars = 3968;
   const unsigned kMaxHSOutputPatchConstantTotalScalars = 32*4;
-  const unsigned kMaxOutputTotalScalars = 32*4;
-  const unsigned kMaxInputTotalScalars = 32*4;
+  const unsigned kMaxSignatureTotalVectors = 32;
+  const unsigned kMaxOutputTotalScalars = kMaxSignatureTotalVectors * 4;
+  const unsigned kMaxInputTotalScalars = kMaxSignatureTotalVectors * 4;
   const unsigned kMaxClipOrCullDistanceElementCount = 2;
   const unsigned kMaxClipOrCullDistanceCount = 2 * 4;
   const unsigned kMaxGSOutputVertexCount = 1024;

+ 9 - 0
lib/HLSL/HLOperationLower.cpp

@@ -5249,6 +5249,15 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
 
     ldInst->replaceAllUsesWith(newLd);
     ldInst->eraseFromParent();
+  } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(user)) {
+    for (auto it = BCI->user_begin(); it != BCI->user_end(); ) {
+      Instruction *I = cast<Instruction>(*it++);
+      TranslateCBAddressUserLegacy(I,
+                                   handle, legacyIdx, channelOffset, hlslOP,
+                                   prevFieldAnnotation, dxilTypeSys,
+                                   DL, pObjHelper);
+    }
+    BCI->eraseFromParent();
   } else {
     // Must be GEP here
     GetElementPtrInst *GEP = cast<GetElementPtrInst>(user);

+ 11 - 5
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -3754,17 +3754,23 @@ static void ReplaceUnboundedArrayUses(Value *V, Value *Src, IRBuilder<> &Builder
 }
 
 static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC) {
+  Type *TyV = V->getType()->getPointerElementType();
+  Type *TySrc = Src->getType()->getPointerElementType();
   if (Constant *C = dyn_cast<Constant>(V)) {
-    if (isa<Constant>(Src)) {
-      V->replaceAllUsesWith(Src);
+    if (TyV == TySrc) {
+      if (isa<Constant>(Src)) {
+        V->replaceAllUsesWith(Src);
+      } else {
+        // Replace Constant with a non-Constant.
+        IRBuilder<> Builder(MC);
+        ReplaceConstantWithInst(C, Src, Builder);
+      }
     } else {
-      // Replace Constant with a non-Constant.
       IRBuilder<> Builder(MC);
+      Src = Builder.CreateBitCast(Src, V->getType());
       ReplaceConstantWithInst(C, Src, Builder);
     }
   } else {
-    Type* TyV = V->getType()->getPointerElementType();
-    Type* TySrc = Src->getType()->getPointerElementType();
     if (TyV == TySrc) {
       if (V != Src)
         V->replaceAllUsesWith(Src);

+ 16 - 3
tools/clang/lib/CodeGen/CGExprAgg.cpp

@@ -726,9 +726,22 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
       Expr *Src = E->getSubExpr();
       switch (CGF.getEvaluationKind(Ty)) {
       case TEK_Aggregate: {
-        LValue LV = CGF.EmitAggExprToLValue(Src);
-        CGF.CGM.getHLSLRuntime().EmitHLSLFlatConversionAggregateCopy(
-            CGF, LV.getAddress(), Src->getType(), DestPtr, E->getType());
+        if (CastExpr *SrcCast = dyn_cast<CastExpr>(Src)) {
+          if (SrcCast->getCastKind() == CK_LValueToRValue) {
+            // Skip the lval to rval cast to reach decl.
+            Src = SrcCast->getSubExpr();
+          }
+        }
+        // Just use decl if possible to skip useless copy.
+        if (DeclRefExpr *SrcDecl = dyn_cast<DeclRefExpr>(Src)) {
+          LValue LV = CGF.EmitLValue(SrcDecl);
+          CGF.CGM.getHLSLRuntime().EmitHLSLFlatConversionAggregateCopy(
+              CGF, LV.getAddress(), Src->getType(), DestPtr, E->getType());
+        } else {
+          LValue LV = CGF.EmitAggExprToLValue(Src);
+          CGF.CGM.getHLSLRuntime().EmitHLSLFlatConversionAggregateCopy(
+              CGF, LV.getAddress(), Src->getType(), DestPtr, E->getType());
+        }
       } break;
       case TEK_Scalar: {
         llvm::Value *SrcVal = CGF.EmitScalarExpr(Src);

+ 50 - 0
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -2394,6 +2394,11 @@ void CGMSHLSLRuntime::AddConstant(VarDecl *constDecl, HLCBuffer &CB) {
     // For static inside cbuffer, take as global static.
     // Don't add to cbuffer.
     CGM.EmitGlobal(constDecl);
+    // Add type annotation for static global types.
+    // May need it when cast from cbuf.
+    DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
+    unsigned arraySize = 0;
+    AddTypeAnnotation(constDecl->getType(), dxilTypeSys, arraySize);
     return;
   }
   // Search defined structure for resource objects and fail
@@ -5975,6 +5980,43 @@ void CGMSHLSLRuntime::EmitHLSLAggregateCopy(CodeGenFunction &CGF, llvm::Value *S
     SmallVector<Value *, 4> idxList;
     EmitHLSLAggregateCopy(CGF, SrcPtr, DestPtr, idxList, Ty, Ty, SrcPtr->getType());
 }
+// To memcpy, need element type match.
+// For struct type, the layout should match in cbuffer layout.
+// struct { float2 x; float3 y; } will not match struct { float3 x; float2 y; }.
+// struct { float2 x; float3 y; } will not match array of float.
+static bool IsTypeMatchForMemcpy(llvm::Type *SrcTy, llvm::Type *DestTy) {
+  llvm::Type *SrcEltTy = dxilutil::GetArrayEltTy(SrcTy);
+  llvm::Type *DestEltTy = dxilutil::GetArrayEltTy(DestTy);
+  if (SrcEltTy == DestEltTy)
+    return true;
+
+  llvm::StructType *SrcST = dyn_cast<llvm::StructType>(SrcEltTy);
+  llvm::StructType *DestST = dyn_cast<llvm::StructType>(DestEltTy);
+  if (SrcST && DestST) {
+    // Only allow identical struct.
+    return SrcST->isLayoutIdentical(DestST);
+  } else if (!SrcST && !DestST) {
+    // For basic type, if one is array, one is not array, layout is different.
+    // If both array, type mismatch. If both basic, copy should be fine.
+    // So all return false.
+    return false;
+  } else {
+    // One struct, one basic type.
+    // Make sure all struct element match the basic type and basic type is
+    // vector4.
+    llvm::StructType *ST = SrcST ? SrcST : DestST;
+    llvm::Type *Ty = SrcST ? DestEltTy : SrcEltTy;
+    if (!Ty->isVectorTy())
+      return false;
+    if (Ty->getVectorNumElements() != 4)
+      return false;
+    for (llvm::Type *EltTy : ST->elements()) {
+      if (EltTy != Ty)
+        return false;
+    }
+    return true;
+  }
+}
 
 void CGMSHLSLRuntime::EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF, llvm::Value *SrcPtr,
     clang::QualType SrcTy,
@@ -5993,6 +6035,14 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF,
     unsigned sizeDest = TheModule.getDataLayout().getTypeAllocSize(DestPtrTy);
     CGF.Builder.CreateMemCpy(DestPtr, SrcPtr, std::max(sizeSrc, sizeDest), 1);
     return;
+  } else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(DestPtr)) {
+    if (GV->isInternalLinkage(GV->getLinkage()) &&
+        IsTypeMatchForMemcpy(SrcPtrTy, DestPtrTy)) {
+      unsigned sizeSrc = TheModule.getDataLayout().getTypeAllocSize(SrcPtrTy);
+      unsigned sizeDest = TheModule.getDataLayout().getTypeAllocSize(DestPtrTy);
+      CGF.Builder.CreateMemCpy(DestPtr, SrcPtr, std::min(sizeSrc, sizeDest), 1);
+      return;
+    }
   }
 
   // It is possiable to implement EmitHLSLAggregateCopy, EmitHLSLAggregateStore

+ 27 - 0
tools/clang/test/CodeGenHLSL/quick-test/constant_cast.hlsl

@@ -0,0 +1,27 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+
+
+// Make sure no store is generated.
+// CHECK-NOT:store {{.*}},
+
+struct ST
+{
+    float4 a;
+    float4 b;
+    float4 c;
+};
+
+
+cbuffer cbModelSkinningConstants : register ( b4 )
+{
+    float4 v[ 2 * 256 * 3 ];
+
+    static const float4 v2d[ 512 ] [ 3 ] = v ;
+    static const ST vst[ 512 ] = v;
+} ;
+
+
+float4 main(int i:I) : SV_Target {
+  return v2d[i][1] + vst[i].b;
+}