2
0
Эх сурвалжийг харах

Add limited support for unbounded resource arrays in parameter list (#399)

- Fix IsHLSLResouceType typo
Tex Riddell 8 жил өмнө
parent
commit
11dd70c410

+ 3 - 1
lib/HLSL/DxilGenerationPass.cpp

@@ -1879,8 +1879,10 @@ void DxilLegalizeResourceUsePass::PromoteLocalResource(Function &F) {
     // the entry node
     // the entry node
     for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I)
     for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I)
       if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { // Is it an alloca?
       if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { // Is it an alloca?
-        if (HandleTy == HLModule::GetArrayEltTy(AI->getAllocatedType()))
+        if (HandleTy == HLModule::GetArrayEltTy(AI->getAllocatedType())) {
+          DXASSERT(isAllocaPromotable(AI), "otherwise, non-promotable resource array alloca found");
           Allocas.push_back(AI);
           Allocas.push_back(AI);
+        }
       }
       }
     if (Allocas.empty())
     if (Allocas.empty())
       break;
       break;

+ 29 - 2
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -3554,7 +3554,8 @@ void PointerStatus::analyzePointer(const Value *V, PointerStatus &PS,
         PS.memcpySet.insert(MI);
         PS.memcpySet.insert(MI);
         bool bFullCopy = false;
         bool bFullCopy = false;
         if (ConstantInt *Length = dyn_cast<ConstantInt>(MC->getLength())) {
         if (ConstantInt *Length = dyn_cast<ConstantInt>(MC->getLength())) {
-          bFullCopy = PS.Size == Length->getLimitedValue();
+          bFullCopy = PS.Size == Length->getLimitedValue()
+            || PS.Size == 0 || Length->getLimitedValue() == 0;  // handle unbounded arrays
         }
         }
         if (MC->getRawDest() == V) {
         if (MC->getRawDest() == V) {
           if (bFullCopy &&
           if (bFullCopy &&
@@ -3689,6 +3690,21 @@ static void ReplaceConstantWithInst(Constant *C, Value *V, IRBuilder<> &Builder)
   }
   }
 }
 }
 
 
+static void ReplaceUnboundedArrayUses(Value *V, Value *Src, IRBuilder<> &Builder) {
+  for (auto it = V->user_begin(); it != V->user_end(); ) {
+    User *U = *(it++);
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
+      SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
+      Value *NewGEP = Builder.CreateGEP(Src, idxList);
+      GEP->replaceAllUsesWith(NewGEP);
+    } else if (BitCastInst *BC = dyn_cast<BitCastInst>(U)) {
+      BC->setOperand(0, Src);
+    } else {
+      DXASSERT(false, "otherwise unbounded array used in unexpected instruction");
+    }
+  }
+}
+
 static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC) {
 static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC) {
   if (Constant *C = dyn_cast<Constant>(V)) {
   if (Constant *C = dyn_cast<Constant>(V)) {
     if (isa<Constant>(Src)) {
     if (isa<Constant>(Src)) {
@@ -3699,7 +3715,18 @@ static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC) {
       ReplaceConstantWithInst(C, Src, Builder);
       ReplaceConstantWithInst(C, Src, Builder);
     }
     }
   } else {
   } else {
-    V->replaceAllUsesWith(Src);
+    Type* TyV = V->getType()->getPointerElementType();
+    Type* TySrc = Src->getType()->getPointerElementType();
+    if (TyV == TySrc) {
+      V->replaceAllUsesWith(Src);
+    } else {
+      DXASSERT((TyV->isArrayTy() && TySrc->isArrayTy()) &&
+               (TyV->getArrayNumElements() == 0 ||
+                TySrc->getArrayNumElements() == 0),
+               "otherwise mismatched types in memcpy are not unbounded array");
+      IRBuilder<> Builder(MC);
+      ReplaceUnboundedArrayUses(V, Src, Builder);
+    }
   }
   }
   Value *RawDest = MC->getOperand(0);
   Value *RawDest = MC->getOperand(0);
   Value *RawSrc = MC->getOperand(1);
   Value *RawSrc = MC->getOperand(1);

+ 1 - 1
tools/clang/include/clang/AST/HlslTypes.h

@@ -360,7 +360,7 @@ bool IsHLSLPointStreamType(clang::QualType type);
 bool IsHLSLLineStreamType(clang::QualType type);
 bool IsHLSLLineStreamType(clang::QualType type);
 bool IsHLSLTriangleStreamType(clang::QualType type);
 bool IsHLSLTriangleStreamType(clang::QualType type);
 bool IsHLSLStreamOutputType(clang::QualType type);
 bool IsHLSLStreamOutputType(clang::QualType type);
-bool IsHLSLResouceType(clang::QualType type);
+bool IsHLSLResourceType(clang::QualType type);
 clang::QualType GetHLSLResourceResultType(clang::QualType type);
 clang::QualType GetHLSLResourceResultType(clang::QualType type);
 bool IsIncompleteHLSLResourceArrayType(clang::ASTContext& context, clang::QualType type);
 bool IsIncompleteHLSLResourceArrayType(clang::ASTContext& context, clang::QualType type);
 clang::QualType GetHLSLInputPatchElementType(clang::QualType type);
 clang::QualType GetHLSLInputPatchElementType(clang::QualType type);

+ 2 - 2
tools/clang/lib/AST/HlslTypes.cpp

@@ -339,7 +339,7 @@ bool IsHLSLStreamOutputType(QualType type) {
   }
   }
   return false;
   return false;
 }
 }
-bool IsHLSLResouceType(clang::QualType type) {
+bool IsHLSLResourceType(clang::QualType type) {
   if (const RecordType *RT = type->getAs<RecordType>()) {
   if (const RecordType *RT = type->getAs<RecordType>()) {
     StringRef name = RT->getDecl()->getName();
     StringRef name = RT->getDecl()->getName();
     if (name == "Texture1D" || name == "RWTexture1D")
     if (name == "Texture1D" || name == "RWTexture1D")
@@ -402,7 +402,7 @@ bool IsIncompleteHLSLResourceArrayType(clang::ASTContext &context,
   if (type->isIncompleteArrayType()) {
   if (type->isIncompleteArrayType()) {
     const IncompleteArrayType *IAT = context.getAsIncompleteArrayType(type);
     const IncompleteArrayType *IAT = context.getAsIncompleteArrayType(type);
     QualType EltTy = IAT->getElementType();
     QualType EltTy = IAT->getElementType();
-    if (IsHLSLResouceType(EltTy))
+    if (IsHLSLResourceType(EltTy))
       return true;
       return true;
   }
   }
   return false;
   return false;

+ 8 - 1
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -897,7 +897,7 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
     DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST);
     DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST);
 
 
     return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
     return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
-  } else if (IsHLSLResouceType(Ty)) {
+  } else if (IsHLSLResourceType(Ty)) {
     // Save result type info.
     // Save result type info.
     AddTypeAnnotation(GetHLSLResourceResultType(Ty), dxilTypeSys, arrayEltSize);
     AddTypeAnnotation(GetHLSLResourceResultType(Ty), dxilTypeSys, arrayEltSize);
     // Resource don't count for cbuffer size.
     // Resource don't count for cbuffer size.
@@ -5528,7 +5528,14 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF,
     unsigned size = TheModule.getDataLayout().getTypeAllocSize(SrcPtrTy);
     unsigned size = TheModule.getDataLayout().getTypeAllocSize(SrcPtrTy);
     CGF.Builder.CreateMemCpy(DestPtr, SrcPtr, size, 1);
     CGF.Builder.CreateMemCpy(DestPtr, SrcPtr, size, 1);
     return;
     return;
+  } else if (HLModule::IsHLSLObjectType(HLModule::GetArrayEltTy(SrcPtrTy)) &&
+             HLModule::IsHLSLObjectType(HLModule::GetArrayEltTy(DestPtrTy))) {
+    unsigned sizeSrc = TheModule.getDataLayout().getTypeAllocSize(SrcPtrTy);
+    unsigned sizeDest = TheModule.getDataLayout().getTypeAllocSize(DestPtrTy);
+    CGF.Builder.CreateMemCpy(DestPtr, SrcPtr, std::max(sizeSrc, sizeDest), 1);
+    return;
   }
   }
+
   // It is possiable to implement EmitHLSLAggregateCopy, EmitHLSLAggregateStore
   // It is possiable to implement EmitHLSLAggregateCopy, EmitHLSLAggregateStore
   // the same way. But split value to scalar will generate many instruction when
   // the same way. But split value to scalar will generate many instruction when
   // src type is same as dest type.
   // src type is same as dest type.

+ 1 - 0
tools/clang/lib/Sema/SemaChecking.cpp

@@ -8230,6 +8230,7 @@ bool Sema::CheckParmsForFunctionDef(ParmVarDecl *const *P,
     //
     //
     // This is also C++ [dcl.fct]p6.
     // This is also C++ [dcl.fct]p6.
     if (!Param->isInvalidDecl() &&
     if (!Param->isInvalidDecl() &&
+        !(getLangOpts().HLSL && Param->getType()->isIncompleteArrayType()) &&  // HLSL Change: allow incomplete array type
         RequireCompleteType(Param->getLocation(), Param->getType(),
         RequireCompleteType(Param->getLocation(), Param->getType(),
                             diag::err_typecheck_decl_incomplete_type)) {
                             diag::err_typecheck_decl_incomplete_type)) {
       Param->setInvalidDecl();
       Param->setInvalidDecl();

+ 2 - 1
tools/clang/lib/Sema/SemaExpr.cpp

@@ -4552,7 +4552,8 @@ bool Sema::GatherArgumentsForCall(SourceLocation CallLoc, FunctionDecl *FDecl,
     if (ArgIx < Args.size()) {
     if (ArgIx < Args.size()) {
       Arg = Args[ArgIx++];
       Arg = Args[ArgIx++];
 
 
-      if (RequireCompleteType(Arg->getLocStart(),
+      if (!(getLangOpts().HLSL && ProtoArgType->isIncompleteArrayType()) && // HLSL Change: allow incomplete array
+          RequireCompleteType(Arg->getLocStart(),
                               ProtoArgType,
                               ProtoArgType,
                               diag::err_call_incomplete_argument, Arg))
                               diag::err_call_incomplete_argument, Arg))
         return true;
         return true;

+ 17 - 0
tools/clang/test/CodeGenHLSL/resource-array-param.hlsl

@@ -0,0 +1,17 @@
+// RUN: %dxc -E main -T ps_6_0 %s
+
+Texture2D Tex4[4];
+Texture2D Tex[];
+
+float4 lookup(Texture2D tex[], int3 coord) {
+  return tex[0].Load(coord);
+}
+
+float4 lookup4(Texture2D tex[4], int3 coord) {
+  return tex[0].Load(coord);
+}
+
+float4 main() : SV_Target
+{
+  return lookup(Tex, 0) + lookup(Tex4, 1) + lookup4(Tex, 2) + lookup4(Tex4, 3);
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -781,6 +781,7 @@ public:
   TEST_METHOD(CodeGenResourceInTB2)
   TEST_METHOD(CodeGenResourceInTB2)
   TEST_METHOD(CodeGenResourceInTBV)
   TEST_METHOD(CodeGenResourceInTBV)
   TEST_METHOD(CodeGenResourceInTBV2)
   TEST_METHOD(CodeGenResourceInTBV2)
+  TEST_METHOD(CodeGenResourceArrayParam)
   TEST_METHOD(CodeGenResPhi)
   TEST_METHOD(CodeGenResPhi)
   TEST_METHOD(CodeGenResPhi2)
   TEST_METHOD(CodeGenResPhi2)
   TEST_METHOD(CodeGenRootSigEntry)
   TEST_METHOD(CodeGenRootSigEntry)
@@ -4106,6 +4107,10 @@ TEST_F(CompilerTest, CodeGenResourceInTBV2) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\resource-in-tbv2.hlsl");
   CodeGenTestCheck(L"..\\CodeGenHLSL\\resource-in-tbv2.hlsl");
 }
 }
 
 
+TEST_F(CompilerTest, CodeGenResourceArrayParam) {
+  CodeGenTest(L"..\\CodeGenHLSL\\resource-array-param.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenResPhi) {
 TEST_F(CompilerTest, CodeGenResPhi) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\resPhi.hlsl");
   CodeGenTestCheck(L"..\\CodeGenHLSL\\resPhi.hlsl");
 }
 }