Kaynağa Gözat

Support resource inside cbuffer. (#175)

Xiang Li 8 yıl önce
ebeveyn
işleme
2b4f3e4801

+ 4 - 1
include/dxc/HLSL/DxilMetadataHelper.h

@@ -173,7 +173,10 @@ public:
   static const char kDxilControlFlowHintMDName[];
 
   // Resource attribute.
-  static const char kDxilResourceAttributeMDName[];
+  static const char kHLDxilResourceAttributeMDName[];
+  static const unsigned kHLDxilResourceAttributeNumFields = 2;
+  static const unsigned kHLDxilResourceAttributeClass = 0;
+  static const unsigned kHLDxilResourceAttributeMeta = 1;
 
   // Precise attribute.
   static const char kDxilPreciseAttributeMDName[];

+ 4 - 0
include/dxc/HLSL/HLModule.h

@@ -188,6 +188,8 @@ public:
   llvm::MDNode *DxilUAVToMDNode(const DxilResource &UAV);
   llvm::MDNode *DxilCBufferToMDNode(const DxilCBuffer &CB);
   DxilResourceBase LoadDxilResourceBaseFromMDNode(llvm::MDNode *MD);
+  void AddResourceWithGlobalVariableAndMDNode(llvm::Constant *GV,
+                                              llvm::MDNode *MD);
 
   // Type related methods.
   static bool IsStreamOutputPtrType(llvm::Type *Ty);
@@ -203,6 +205,8 @@ public:
                                       DxilParameterAnnotation &paramAnnotation);
   static const char *GetLegacyDataLayoutDesc();
 
+  static void MergeGepUse(llvm::Value *V);
+
   // HL code gen.
   template<class BuilderTy>
   static llvm::CallInst *EmitHLOperationCall(BuilderTy &Builder,

+ 2 - 0
lib/HLSL/DxilGenerationPass.cpp

@@ -2259,6 +2259,8 @@ void DxilGenerationPass::GenerateDxilCBufferHandles(
   for (size_t i = 0; i < m_pHLModule->GetCBuffers().size(); i++) {
     DxilCBuffer &CB = m_pHLModule->GetCBuffer(i);
     GlobalVariable *GV = cast<GlobalVariable>(CB.GetGlobalSymbol());
+    // Remove GEP created in HLObjectOperationLowerHelper::UniformCbPtr.
+    GV->removeDeadConstantUsers();
     std::string handleName = std::string(GV->getName()) + "_buffer";
 
     Value *args[] = {opArg, resClassArg, nullptr, nullptr,

+ 1 - 1
lib/HLSL/DxilMetadataHelper.cpp

@@ -44,7 +44,7 @@ const char DxilMDHelper::kDxilTypeSystemMDName[]                      = "dx.type
 const char DxilMDHelper::kDxilTypeSystemHelperVariablePrefix[]        = "dx.typevar.";
 const char DxilMDHelper::kDxilControlFlowHintMDName[]                 = "dx.controlflow.hints";
 const char DxilMDHelper::kDxilPreciseAttributeMDName[]                = "dx.precise";
-const char DxilMDHelper::kDxilResourceAttributeMDName[]               = "dx.resource.attribute";
+const char DxilMDHelper::kHLDxilResourceAttributeMDName[]             = "dx.hl.resource.attribute";
 const char DxilMDHelper::kDxilValidatorVersionMDName[]                = "dx.valver";
 
 // This named metadata is not valid in final module (should be moved to DxilContainer)

+ 145 - 12
lib/HLSL/HLModule.cpp

@@ -25,6 +25,7 @@
 #include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/DIBuilder.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
 
 using namespace llvm;
 using std::string;
@@ -741,35 +742,34 @@ MDNode *HLModule::DxilCBufferToMDNode(const DxilCBuffer &CB) {
 
 DxilResourceBase HLModule::LoadDxilResourceBaseFromMDNode(
                                               MDNode *MD) {
-  const unsigned kDxilResourceAttributeNumFields = 2;
-  const unsigned kDxilResourceAttributeClass = 0;
-  const unsigned kDxilResourceAttributeMeta = 1;
-  IFTBOOL(MD->getNumOperands() >= kDxilResourceAttributeNumFields,
+  IFTBOOL(MD->getNumOperands() >= DxilMDHelper::kHLDxilResourceAttributeNumFields,
           DXC_E_INCORRECT_DXIL_METADATA);
 
   DxilResource::Class RC =
       static_cast<DxilResource::Class>(m_pMDHelper->ConstMDToUint32(
-          MD->getOperand(kDxilResourceAttributeClass)));
+          MD->getOperand(DxilMDHelper::kHLDxilResourceAttributeClass)));
+  const MDOperand &Meta =
+      MD->getOperand(DxilMDHelper::kHLDxilResourceAttributeMeta);
+
   switch (RC) {
   case DxilResource::Class::CBuffer: {
     DxilCBuffer CB;
-    m_pMDHelper->LoadDxilCBuffer(MD->getOperand(kDxilResourceAttributeMeta),
-                                 CB);
+    m_pMDHelper->LoadDxilCBuffer(Meta, CB);
     return CB;
   } break;
   case DxilResource::Class::Sampler: {
     DxilSampler S;
-    m_pMDHelper->LoadDxilSampler(MD->getOperand(kDxilResourceAttributeMeta), S);
+    m_pMDHelper->LoadDxilSampler(Meta, S);
     return S;
   } break;
   case DxilResource::Class::SRV: {
     DxilResource Res;
-    m_pMDHelper->LoadDxilSRV(MD->getOperand(kDxilResourceAttributeMeta), Res);
+    m_pMDHelper->LoadDxilSRV(Meta, Res);
     return Res;
   } break;
   case DxilResource::Class::UAV: {
     DxilResource Res;
-    m_pMDHelper->LoadDxilUAV(MD->getOperand(kDxilResourceAttributeMeta), Res);
+    m_pMDHelper->LoadDxilUAV(Meta, Res);
     return Res;
   } break;
   default:
@@ -778,6 +778,51 @@ DxilResourceBase HLModule::LoadDxilResourceBaseFromMDNode(
   }
 }
 
+void HLModule::AddResourceWithGlobalVariableAndMDNode(llvm::Constant *GV,
+                                                      llvm::MDNode *MD) {
+  IFTBOOL(MD->getNumOperands() >= DxilMDHelper::kHLDxilResourceAttributeNumFields,
+          DXC_E_INCORRECT_DXIL_METADATA);
+
+  DxilResource::Class RC =
+      static_cast<DxilResource::Class>(m_pMDHelper->ConstMDToUint32(
+          MD->getOperand(DxilMDHelper::kHLDxilResourceAttributeClass)));
+  const MDOperand &Meta =
+      MD->getOperand(DxilMDHelper::kHLDxilResourceAttributeMeta);
+  unsigned rangeSize = 1;
+  Type *Ty = GV->getType()->getPointerElementType();
+  if (ArrayType *AT = dyn_cast<ArrayType>(Ty))
+    rangeSize = AT->getNumElements();
+
+  switch (RC) {
+  case DxilResource::Class::Sampler: {
+    std::unique_ptr<DxilSampler> S = std::make_unique<DxilSampler>();
+    m_pMDHelper->LoadDxilSampler(Meta, *S);
+    S->SetGlobalSymbol(GV);
+    S->SetGlobalName(GV->getName());
+    S->SetRangeSize(rangeSize);
+    AddSampler(std::move(S));
+  } break;
+  case DxilResource::Class::SRV: {
+    std::unique_ptr<HLResource> Res = std::make_unique<HLResource>();
+    m_pMDHelper->LoadDxilSRV(Meta, *Res);
+    Res->SetGlobalSymbol(GV);
+    Res->SetGlobalName(GV->getName());
+    Res->SetRangeSize(rangeSize);
+    AddSRV(std::move(Res));
+  } break;
+  case DxilResource::Class::UAV: {
+    std::unique_ptr<HLResource> Res = std::make_unique<HLResource>();
+    m_pMDHelper->LoadDxilUAV(Meta, *Res);
+    Res->SetGlobalSymbol(GV);
+    Res->SetGlobalName(GV->getName());
+    Res->SetRangeSize(rangeSize);
+    AddUAV(std::move(Res));
+  } break;
+  default:
+    DXASSERT(0, "Invalid metadata for AddResourceWithGlobalVariableAndMDNode");
+  }
+}
+
 // TODO: Don't check names.
 bool HLModule::IsStreamOutputType(llvm::Type *Ty) {
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
@@ -958,6 +1003,94 @@ const char *HLModule::GetLegacyDataLayoutDesc() {
   return kLegacyLayoutString.data();
 }
 
+static Value *MergeGEP(GEPOperator *SrcGEP, GetElementPtrInst *GEP) {
+  IRBuilder<> Builder(GEP);
+  SmallVector<Value *, 8> Indices;
+
+  // Find out whether the last index in the source GEP is a sequential idx.
+  bool EndsWithSequential = false;
+  for (gep_type_iterator I = gep_type_begin(*SrcGEP), E = gep_type_end(*SrcGEP);
+       I != E; ++I)
+    EndsWithSequential = !(*I)->isStructTy();
+  if (EndsWithSequential) {
+    Value *Sum;
+    Value *SO1 = SrcGEP->getOperand(SrcGEP->getNumOperands() - 1);
+    Value *GO1 = GEP->getOperand(1);
+    if (SO1 == Constant::getNullValue(SO1->getType())) {
+      Sum = GO1;
+    } else if (GO1 == Constant::getNullValue(GO1->getType())) {
+      Sum = SO1;
+    } else {
+      // If they aren't the same type, then the input hasn't been processed
+      // by the loop above yet (which canonicalizes sequential index types to
+      // intptr_t).  Just avoid transforming this until the input has been
+      // normalized.
+      if (SO1->getType() != GO1->getType())
+        return nullptr;
+      // Only do the combine when GO1 and SO1 are both constants. Only in
+      // this case, we are sure the cost after the merge is never more than
+      // that before the merge.
+      if (!isa<Constant>(GO1) || !isa<Constant>(SO1))
+        return nullptr;
+      Sum = Builder.CreateAdd(SO1, GO1);
+    }
+
+    // Update the GEP in place if possible.
+    if (SrcGEP->getNumOperands() == 2) {
+      GEP->setOperand(0, SrcGEP->getOperand(0));
+      GEP->setOperand(1, Sum);
+      return GEP;
+    }
+    Indices.append(SrcGEP->op_begin() + 1, SrcGEP->op_end() - 1);
+    Indices.push_back(Sum);
+    Indices.append(GEP->op_begin() + 2, GEP->op_end());
+  } else if (isa<Constant>(*GEP->idx_begin()) &&
+             cast<Constant>(*GEP->idx_begin())->isNullValue() &&
+             SrcGEP->getNumOperands() != 1) {
+    // Otherwise we can do the fold if the first index of the GEP is a zero
+    Indices.append(SrcGEP->op_begin() + 1, SrcGEP->op_end());
+    Indices.append(GEP->idx_begin() + 1, GEP->idx_end());
+  }
+  if (!Indices.empty())
+    return Builder.CreateInBoundsGEP(SrcGEP->getSourceElementType(),
+                                     SrcGEP->getOperand(0), Indices,
+                                     GEP->getName());
+  else
+    llvm_unreachable("must merge");
+}
+
+void HLModule::MergeGepUse(Value *V) {
+  for (auto U = V->user_begin(); U != V->user_end();) {
+    auto Use = U++;
+
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(*Use)) {
+      if (GEPOperator *prevGEP = dyn_cast<GEPOperator>(V)) {
+        // merge the 2 GEPs
+        Value *newGEP = MergeGEP(prevGEP, GEP);
+        GEP->replaceAllUsesWith(newGEP);
+        GEP->eraseFromParent();
+        MergeGepUse(newGEP);
+      } else {
+        MergeGepUse(*Use);
+      }
+    } else if (GEPOperator *GEPOp = dyn_cast<GEPOperator>(*Use)) {
+      if (GEPOperator *prevGEP = dyn_cast<GEPOperator>(V)) {
+        // merge the 2 GEPs
+        Value *newGEP = MergeGEP(prevGEP, GEP);
+        GEP->replaceAllUsesWith(newGEP);
+        GEP->eraseFromParent();
+        MergeGepUse(newGEP);
+      } else {
+        MergeGepUse(*Use);
+      }
+    }
+  }
+  if (V->user_empty()) {
+    if (Instruction *I = dyn_cast<Instruction>(V))
+      I->eraseFromParent();
+  }
+}
+
 template<typename BuilderTy>
 CallInst *HLModule::EmitHLOperationCall(BuilderTy &Builder,
                                            HLOpcodeGroup group, unsigned opcode,
@@ -1128,11 +1261,11 @@ bool HLModule::HasPreciseAttribute(Function *F) {
 }
 
 void HLModule::MarkDxilResourceAttrib(llvm::Function *F, MDNode *MD) {
-  F->setMetadata(DxilMDHelper::kDxilResourceAttributeMDName, MD);
+  F->setMetadata(DxilMDHelper::kHLDxilResourceAttributeMDName, MD);
 }
 
 MDNode *HLModule::GetDxilResourceAttrib(llvm::Function *F) {
-  return F->getMetadata(DxilMDHelper::kDxilResourceAttributeMDName);
+  return F->getMetadata(DxilMDHelper::kHLDxilResourceAttributeMDName);
 }
 
 DIGlobalVariable *

+ 163 - 19
lib/HLSL/HLOperationLower.cpp

@@ -67,18 +67,25 @@ private:
   std::unordered_map<Value *, ResAttribute> HandleMetaMap;
   std::unordered_set<LoadInst *> &UpdateCounterSet;
   std::unordered_set<Value *> &NonUniformSet;
+  // Map from pointer of cbuffer to pointer of resource.
+  // For cbuffer like this:
+  //   cbuffer A {
+  //     Texture2D T;
+  //   };
+  // A global resource Texture2D T2 will be created for Texture2D T.
+  // CBPtrToResourceMap[T] will return T2.
+  std::unordered_map<Value *, Value *> CBPtrToResourceMap;
 
 public:
   HLObjectOperationLowerHelper(HLModule &HLM,
                                std::unordered_set<LoadInst *> &UpdateCounter,
                                std::unordered_set<Value *> &NonUniform)
       : HLM(HLM), UpdateCounterSet(UpdateCounter), NonUniformSet(NonUniform) {}
-
   DXIL::ResourceClass GetRC(Value *Handle) {
     ResAttribute &Res = FindCreateHandleResourceBase(Handle);
     return Res.RC;
   }
-  DXIL::ResourceKind  GetRK(Value *Handle) {
+  DXIL::ResourceKind GetRK(Value *Handle) {
     ResAttribute &Res = FindCreateHandleResourceBase(Handle);
     return Res.RK;
   }
@@ -89,13 +96,50 @@ public:
 
   void MarkHasCounter(Type *Ty, Value *handle) {
     DXIL::ResourceClass RC = GetRC(handle);
-    DXASSERT_LOCALVAR(RC, RC == DXIL::ResourceClass::UAV, "must UAV for counter");
-    std::unordered_set<Value*> resSet;
+    DXASSERT_LOCALVAR(RC, RC == DXIL::ResourceClass::UAV,
+                      "must UAV for counter");
+    std::unordered_set<Value *> resSet;
     MarkHasCounterOnCreateHandle(handle, resSet);
   }
-  void MarkNonUniform(Value *V) {
-    NonUniformSet.insert(V);
+  void MarkNonUniform(Value *V) { NonUniformSet.insert(V); }
+
+  Value *GetOrCreateResourceForCbPtr(GetElementPtrInst *CbPtr,
+                                     GlobalVariable *CbGV, MDNode *MD) {
+    // Change array idx to 0 to make sure all array ptr share same key.
+    Value *Key = UniformCbPtr(CbPtr, CbGV);
+    if (CBPtrToResourceMap.count(Key))
+      return CBPtrToResourceMap[Key];
+    Value *Resource = CreateResourceForCbPtr(CbPtr, CbGV, MD);
+    CBPtrToResourceMap[Key] = Resource;
+    return Resource;
+  }
+
+  Value *LowerCbResourcePtr(GetElementPtrInst *CbPtr, Value *ResPtr) {
+    // Simple case.
+    if (ResPtr->getType() == CbPtr->getType())
+      return ResPtr;
+
+    // Array case.
+    DXASSERT_NOMSG(ResPtr->getType()->getPointerElementType()->isArrayTy());
+
+    IRBuilder<> Builder(CbPtr);
+    gep_type_iterator GEPIt = gep_type_begin(CbPtr), E = gep_type_end(CbPtr);
+
+    Value *arrayIdx = GEPIt.getOperand();
+
+    // Only calc array idx and size.
+    // Ignore struct type part.
+    for (; GEPIt != E; ++GEPIt) {
+      if (GEPIt->isArrayTy()) {
+        arrayIdx = Builder.CreateMul(
+            arrayIdx, Builder.getInt32(GEPIt->getArrayNumElements()));
+        arrayIdx = Builder.CreateAdd(arrayIdx, GEPIt.getOperand());
+      }
+    }
+
+    return Builder.CreateGEP(ResPtr, {Builder.getInt32(0), arrayIdx});
   }
+
 private:
   ResAttribute &FindCreateHandleResourceBase(Value *Handle) {
     if (HandleMetaMap.count(Handle))
@@ -146,7 +190,8 @@ private:
 
     return HandleMetaMap[Handle];
   }
-  CallInst *FindCreateHandle(Value *handle, std::unordered_set<Value *> &resSet) {
+  CallInst *FindCreateHandle(Value *handle,
+                             std::unordered_set<Value *> &resSet) {
     // Already checked.
     if (resSet.count(handle))
       return nullptr;
@@ -156,9 +201,9 @@ private:
       return CI;
     if (SelectInst *Sel = dyn_cast<SelectInst>(handle)) {
       if (CallInst *CI = FindCreateHandle(Sel->getTrueValue(), resSet))
-          return CI;
+        return CI;
       if (CallInst *CI = FindCreateHandle(Sel->getFalseValue(), resSet))
-          return CI;
+        return CI;
       return nullptr;
     }
     if (PHINode *Phi = dyn_cast<PHINode>(handle)) {
@@ -179,7 +224,8 @@ private:
     resSet.insert(handle);
 
     if (CallInst *CI = dyn_cast<CallInst>(handle)) {
-      Value *Res = CI->getArgOperand(HLOperandIndex::kCreateHandleResourceOpIdx);
+      Value *Res =
+          CI->getArgOperand(HLOperandIndex::kCreateHandleResourceOpIdx);
       LoadInst *LdRes = dyn_cast<LoadInst>(Res);
       if (!LdRes) {
         CI->getContext().emitError(CI, "cannot map resource to handle");
@@ -198,6 +244,67 @@ private:
       }
     }
   }
+
+  Value *UniformCbPtr(GetElementPtrInst *CbPtr, GlobalVariable *CbGV) {
+    gep_type_iterator GEPIt = gep_type_begin(CbPtr), E = gep_type_end(CbPtr);
+    std::vector<Value *> idxList(CbPtr->idx_begin(), CbPtr->idx_end());
+    unsigned i = 0;
+    IRBuilder<> Builder(HLM.GetCtx());
+    Value *zero = Builder.getInt32(0);
+    for (; GEPIt != E; ++GEPIt, ++i) {
+      if (GEPIt->isArrayTy()) {
+        // Change array idx to 0 to make sure all array ptr share same key.
+        idxList[i] = zero;
+      }
+    }
+
+    Value *Key = Builder.CreateInBoundsGEP(CbGV, idxList);
+    return Key;
+  }
+
+  Value *CreateResourceForCbPtr(GetElementPtrInst *CbPtr, GlobalVariable *CbGV,
+                                MDNode *MD) {
+    Type *CbTy = CbPtr->getPointerOperandType();
+    DXASSERT_NOMSG(CbTy == CbGV->getType());
+
+    gep_type_iterator GEPIt = gep_type_begin(CbPtr), E = gep_type_end(CbPtr);
+    unsigned i = 0;
+    IRBuilder<> Builder(HLM.GetCtx());
+    unsigned arraySize = 1;
+    DxilTypeSystem &typeSys = HLM.GetTypeSystem();
+
+    std::string Name;
+    for (; GEPIt != E; ++GEPIt, ++i) {
+      if (GEPIt->isArrayTy()) {
+        arraySize *= GEPIt->getArrayNumElements();
+      } else if (GEPIt->isStructTy()) {
+        DxilStructAnnotation *typeAnnot =
+            typeSys.GetStructAnnotation(cast<StructType>(*GEPIt));
+        DXASSERT_NOMSG(typeAnnot);
+        unsigned idx = cast<ConstantInt>(GEPIt.getOperand())->getLimitedValue();
+        DXASSERT_NOMSG(typeAnnot->GetNumFields() > idx);
+        DxilFieldAnnotation &fieldAnnot = typeAnnot->GetFieldAnnotation(idx);
+        if (!Name.empty())
+          Name += ".";
+        Name += fieldAnnot.GetFieldName();
+      }
+    }
+
+    Type *Ty = CbPtr->getResultElementType();
+    if (arraySize > 1) {
+      Ty = ArrayType::get(Ty, arraySize);
+    }
+
+    return CreateResourceGV(Ty, Name, MD);
+  }
+
+  Value *CreateResourceGV(Type *Ty, StringRef Name, MDNode *MD) {
+    Module &M = *HLM.GetModule();
+    Constant *GV = M.getOrInsertGlobal(Name, Ty);
+    // Create resource and set GV as globalSym.
+    HLM.AddResourceWithGlobalVariableAndMDNode(GV, MD);
+    return GV;
+  }
 };
 
 using IntrinsicLowerFuncTy = Value *(CallInst *CI, IntrinsicOp IOP,
@@ -4697,14 +4804,38 @@ void TranslateCBGepLegacy(GetElementPtrInst *GEP, Value *handle,
                           Value *legacyIdx, unsigned channelOffset,
                           hlsl::OP *hlslOP, IRBuilder<> &Builder,
                           DxilFieldAnnotation *prevFieldAnnotation,
-                          const DataLayout &DL, DxilTypeSystem &dxilTypeSys);
+                          const DataLayout &DL, DxilTypeSystem &dxilTypeSys,
+                          HLObjectOperationLowerHelper *pObjHelper);
+
+void TranslateResourceInCB(LoadInst *LI,
+                           HLObjectOperationLowerHelper *pObjHelper,
+                           GlobalVariable *CbGV) {
+  if (LI->user_empty()) {
+    LI->eraseFromParent();
+    return;
+  }
+
+  GetElementPtrInst *Ptr = cast<GetElementPtrInst>(LI->getPointerOperand());
+  CallInst *CI = cast<CallInst>(LI->user_back());
+  MDNode *MD = HLModule::GetDxilResourceAttrib(CI->getCalledFunction());
+
+  Value *ResPtr = pObjHelper->GetOrCreateResourceForCbPtr(Ptr, CbGV, MD);
+
+  // Lower Ptr to GV base Ptr.
+  Value *GvPtr = pObjHelper->LowerCbResourcePtr(Ptr, ResPtr);
+  IRBuilder<> Builder(LI);
+  Value *GvLd = Builder.CreateLoad(GvPtr);
+  LI->replaceAllUsesWith(GvLd);
+  LI->eraseFromParent();
+}
 
 void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
                                   Value *legacyIdx, unsigned channelOffset,
                                   hlsl::OP *hlslOP,
                                   DxilFieldAnnotation *prevFieldAnnotation,
                                   DxilTypeSystem &dxilTypeSys,
-                                  const DataLayout &DL) {
+                                  const DataLayout &DL,
+                                  HLObjectOperationLowerHelper *pObjHelper) {
   Value *zeroIdx = hlslOP->GetU32Const(0);
 
   IRBuilder<> Builder(user);
@@ -4879,6 +5010,14 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
   } else if (LoadInst *ldInst = dyn_cast<LoadInst>(user)) {
     Type *Ty = ldInst->getType();
     Type *EltTy = Ty->getScalarType();
+    // Resource inside cbuffer is lowered after GenerateDxilOperations.
+    if (HLModule::IsHLSLObjectType(Ty)) {
+      CallInst *CI = cast<CallInst>(handle);
+      GlobalVariable *CbGV = cast<GlobalVariable>(
+          CI->getArgOperand(HLOperandIndex::kCreateHandleResourceOpIdx));
+      TranslateResourceInCB(ldInst, pObjHelper, CbGV);
+      return;
+    }
     DXASSERT(!Ty->isAggregateType(), "should be flat in previous pass");
     
     Value *newLd = nullptr;
@@ -4897,7 +5036,7 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
     // Must be GEP here
     GetElementPtrInst *GEP = cast<GetElementPtrInst>(user);
     TranslateCBGepLegacy(GEP, handle, legacyIdx, channelOffset, hlslOP, Builder,
-                         prevFieldAnnotation, DL, dxilTypeSys);
+                         prevFieldAnnotation, DL, dxilTypeSys, pObjHelper);
     GEP->eraseFromParent();
   }
 }
@@ -4906,7 +5045,8 @@ void TranslateCBGepLegacy(GetElementPtrInst *GEP, Value *handle,
                           Value *legacyIndex, unsigned channel,
                           hlsl::OP *hlslOP, IRBuilder<> &Builder,
                           DxilFieldAnnotation *prevFieldAnnotation,
-                          const DataLayout &DL, DxilTypeSystem &dxilTypeSys) {
+                          const DataLayout &DL, DxilTypeSystem &dxilTypeSys,
+                          HLObjectOperationLowerHelper *pObjHelper) {
   SmallVector<Value *, 8> Indices(GEP->idx_begin(), GEP->idx_end());
 
   // update offset
@@ -5060,20 +5200,23 @@ void TranslateCBGepLegacy(GetElementPtrInst *GEP, Value *handle,
     Instruction *user = cast<Instruction>(*(U++));
 
     TranslateCBAddressUserLegacy(user, handle, legacyIndex, channel, hlslOP, fieldAnnotation,
-                           dxilTypeSys, DL);
+                           dxilTypeSys, DL, pObjHelper);
   }
 }
 
 void TranslateCBOperationsLegacy(Value *handle, Value *ptr, OP *hlslOP,
-                           DxilTypeSystem &dxilTypeSys, const DataLayout &DL) {
+                                 DxilTypeSystem &dxilTypeSys,
+                                 const DataLayout &DL,
+                                 HLObjectOperationLowerHelper *pObjHelper) {
   auto User = ptr->user_begin();
   auto UserE = ptr->user_end();
   Value *zeroIdx = hlslOP->GetU32Const(0);
   for (; User != UserE;) {
     // Must be Instruction.
     Instruction *I = cast<Instruction>(*(User++));
-    TranslateCBAddressUserLegacy(I, handle, zeroIdx, /*channelOffset*/0, hlslOP,
-                           /*prevFieldAnnotation*/ nullptr, dxilTypeSys, DL);
+    TranslateCBAddressUserLegacy(
+        I, handle, zeroIdx, /*channelOffset*/ 0, hlslOP,
+        /*prevFieldAnnotation*/ nullptr, dxilTypeSys, DL, pObjHelper);
   }
 }
 
@@ -6068,11 +6211,12 @@ void TranslateHLSubscript(CallInst *CI, HLSubscriptOpcode opcode,
 
   Value *ptr = CI->getArgOperand(HLOperandIndex::kSubscriptObjectOpIdx);
   if (opcode == HLSubscriptOpcode::CBufferSubscript) {
+    HLModule::MergeGepUse(CI);
     // Resource ptr.
     Value *handle = CI->getArgOperand(HLOperandIndex::kSubscriptObjectOpIdx);
     if (helper.bLegacyCBufferLoad)
       TranslateCBOperationsLegacy(handle, CI, hlslOP, helper.dxilTypeSys,
-                                  helper.legacyDataLayout);
+                                  helper.legacyDataLayout, pObjHelper);
     else {
       TranslateCBOperations(handle, CI, /*offset*/ hlslOP->GetU32Const(0),
                             hlslOP, helper.dxilTypeSys,

+ 10 - 98
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -1498,95 +1498,6 @@ bool SROA_HLSL::ShouldAttemptScalarRepl(AllocaInst *AI) {
   return false;
 }
 
-static Value *MergeGEP(GEPOperator *SrcGEP, GetElementPtrInst *GEP) {
-  IRBuilder<> Builder(GEP);
-  SmallVector<Value *, 8> Indices;
-
-  // Find out whether the last index in the source GEP is a sequential idx.
-  bool EndsWithSequential = false;
-  for (gep_type_iterator I = gep_type_begin(*SrcGEP), E = gep_type_end(*SrcGEP);
-       I != E; ++I)
-    EndsWithSequential = !(*I)->isStructTy();
-  if (EndsWithSequential) {
-    Value *Sum;
-    Value *SO1 = SrcGEP->getOperand(SrcGEP->getNumOperands() - 1);
-    Value *GO1 = GEP->getOperand(1);
-    if (SO1 == Constant::getNullValue(SO1->getType())) {
-      Sum = GO1;
-    } else if (GO1 == Constant::getNullValue(GO1->getType())) {
-      Sum = SO1;
-    } else {
-      // If they aren't the same type, then the input hasn't been processed
-      // by the loop above yet (which canonicalizes sequential index types to
-      // intptr_t).  Just avoid transforming this until the input has been
-      // normalized.
-      if (SO1->getType() != GO1->getType())
-        return nullptr;
-      // Only do the combine when GO1 and SO1 are both constants. Only in
-      // this case, we are sure the cost after the merge is never more than
-      // that before the merge.
-      if (!isa<Constant>(GO1) || !isa<Constant>(SO1))
-        return nullptr;
-      Sum = Builder.CreateAdd(SO1, GO1);
-    }
-
-    // Update the GEP in place if possible.
-    if (SrcGEP->getNumOperands() == 2) {
-      GEP->setOperand(0, SrcGEP->getOperand(0));
-      GEP->setOperand(1, Sum);
-      return GEP;
-    }
-    Indices.append(SrcGEP->op_begin() + 1, SrcGEP->op_end() - 1);
-    Indices.push_back(Sum);
-    Indices.append(GEP->op_begin() + 2, GEP->op_end());
-  } else if (isa<Constant>(*GEP->idx_begin()) &&
-             cast<Constant>(*GEP->idx_begin())->isNullValue() &&
-             SrcGEP->getNumOperands() != 1) {
-    // Otherwise we can do the fold if the first index of the GEP is a zero
-    Indices.append(SrcGEP->op_begin() + 1, SrcGEP->op_end());
-    Indices.append(GEP->idx_begin() + 1, GEP->idx_end());
-  }
-  if (!Indices.empty())
-    return Builder.CreateInBoundsGEP(SrcGEP->getSourceElementType(),
-                                     SrcGEP->getOperand(0), Indices,
-                                     GEP->getName());
-  else
-    llvm_unreachable("must merge");
-}
-
-static void MergeGepUse(Value *V) {
-  for (auto U = V->user_begin(); U != V->user_end();) {
-    auto Use = U++;
-
-    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(*Use)) {
-      if (GEPOperator *prevGEP = dyn_cast<GEPOperator>(V)) {
-        // merge the 2 GEPs
-        Value *newGEP = MergeGEP(prevGEP, GEP);
-        GEP->replaceAllUsesWith(newGEP);
-        GEP->eraseFromParent();
-        MergeGepUse(newGEP);
-      } else {
-        MergeGepUse(*Use);
-      }
-    }
-    else if (GEPOperator *GEPOp = dyn_cast<GEPOperator>(*Use)) {
-      if (GEPOperator *prevGEP = dyn_cast<GEPOperator>(V)) {
-        // merge the 2 GEPs
-        Value *newGEP = MergeGEP(prevGEP, GEP);
-        GEP->replaceAllUsesWith(newGEP);
-        GEP->eraseFromParent();
-        MergeGepUse(newGEP);
-      } else {
-        MergeGepUse(*Use);
-      }
-    }
-  }
-  if (V->user_empty()) {
-    if (Instruction *I = dyn_cast<Instruction>(V))
-      I->eraseFromParent();
-  }
-}
-
 // performScalarRepl - This algorithm is a simple worklist driven algorithm,
 // which runs on all of the alloca instructions in the entry block, removing
 // them if they are only used by getelementptr instructions.
@@ -1608,7 +1519,7 @@ bool SROA_HLSL::performScalarRepl(Function &F) {
 
   // merge GEP use for the allocs
   for (auto A : AllocaList)
-    MergeGepUse(A);
+    HLModule::MergeGepUse(A);
 
   DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false);
 
@@ -3449,8 +3360,9 @@ public:
     for (Function &F : M.functions()) {
       HLOpcodeGroup group = GetHLOpcodeGroup(&F);
       // Skip HL operations.
-      if (group != HLOpcodeGroup::NotHL || group == HLOpcodeGroup::HLExtIntrinsic)
+      if (group != HLOpcodeGroup::NotHL || group == HLOpcodeGroup::HLExtIntrinsic) {
         continue;
+      }
 
       if (F.isDeclaration()) {
         // Skip llvm intrinsic.
@@ -3489,7 +3401,7 @@ public:
         staticGVs.emplace_back(&GV);
       } else {
         // merge GEP use for global.
-        MergeGepUse(&GV);
+        HLModule::MergeGepUse(&GV);
       }
     }
 
@@ -3631,7 +3543,7 @@ void SROA_Parameter_HLSL::flattenGlobal(GlobalVariable *GV) {
   std::deque<Value *> WorkList;
   WorkList.push_back(GV);
   // merge GEP use for global.
-  MergeGepUse(GV);
+  HLModule::MergeGepUse(GV);
   Function *Entry = m_pHLModule->GetEntryFunction();
   
   DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
@@ -4466,7 +4378,7 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
   // Add all argument to worklist.
   for (Argument &Arg : F->args()) {
     // merge GEP use for arg.
-    MergeGepUse(&Arg);
+    HLModule::MergeGepUse(&Arg);
     // Insert point may be removed. So recreate builder every time.
     IRBuilder<> Builder(F->getEntryBlock().getFirstInsertionPt());
     DxilParameterAnnotation &paramAnnotation =
@@ -4650,7 +4562,7 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
       DDI->setArgOperand(0, VMD);
     }
 
-    MergeGepUse(Arg);
+    HLModule::MergeGepUse(Arg);
     // Flatten store of array parameter.
     if (Arg->getType()->isPointerTy()) {
       Type *Ty = Arg->getType()->getPointerElementType();
@@ -5350,7 +5262,7 @@ bool DynamicIndexingVectorToArray::runOnModule(Module &M) {
   for (GlobalVariable &GV : M.globals()) {
     if (HLModule::IsStaticGlobal(&GV) || HLModule::IsSharedMemoryGlobal(&GV)) {
       // Merge all GEP.
-      MergeGepUse(&GV);
+      HLModule::MergeGepUse(&GV);
     }
   }
   return true;
@@ -5463,7 +5375,7 @@ void MultiDimArrayToOneDimArray::flattenAlloca(AllocaInst *AI) {
   IRBuilder<> Builder(AI);
   Value *NewAI = Builder.CreateAlloca(AT);
   // Merge all GEP of AI.
-  MergeGepUse(AI);
+  HLModule::MergeGepUse(AI);
 
   flattenMultiDimArray(AI, NewAI);
   AI->eraseFromParent();
@@ -5581,7 +5493,7 @@ bool MultiDimArrayToOneDimArray::runOnModule(Module &M) {
   for (GlobalVariable &GV : M.globals()) {
     if (HLModule::IsStaticGlobal(&GV) || HLModule::IsSharedMemoryGlobal(&GV)) {
       // Merge all GEP.
-      MergeGepUse(&GV);
+      HLModule::MergeGepUse(&GV);
       if (IsMultiDimArrayType(GV.getType()->getElementType()) &&
           !GV.user_empty())
         multiDimGVs.emplace_back(&GV);

+ 9 - 7
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -879,9 +879,12 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
     DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST);
 
     return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
-  } else if (IsHLSLResouceType(Ty))
-    return AddTypeAnnotation(GetHLSLResourceResultType(Ty), dxilTypeSys, arrayEltSize);
-  else {
+  } else if (IsHLSLResouceType(Ty)) {
+    // Save result type info.
+    AddTypeAnnotation(GetHLSLResourceResultType(Ty), dxilTypeSys, arrayEltSize);
+    // Resource don't count for cbuffer size.
+    return 0;
+  } else {
     unsigned arraySize = 0;
     QualType arrayElementTy = Ty;
     if (Ty->isConstantArrayType()) {
@@ -2255,17 +2258,16 @@ void CGMSHLSLRuntime::AddConstant(VarDecl *constDecl, HLCBuffer &CB) {
     CGM.EmitGlobal(constDecl);
     return;
   }
-
   // Search defined structure for resource objects and fail
-  if (IsResourceInType(CGM.getContext(), constDecl->getType())) {
+  if (CB.GetRangeSize() > 1 &&
+      IsResourceInType(CGM.getContext(), constDecl->getType())) {
     DiagnosticsEngine &Diags = CGM.getDiags();
     unsigned DiagID = Diags.getCustomDiagID(
         DiagnosticsEngine::Error,
-        "object types not supported in global aggregate instances, cbuffers, or tbuffers.");
+        "object types not supported in cbuffer/tbuffer view arrays.");
     Diags.Report(constDecl->getLocation(), DiagID);
     return;
   }
-
   llvm::Constant *constVal = CGM.GetAddrOfGlobalVar(constDecl);
 
   bool isGlobalCB = CB.GetID() == globalCBIndex;

+ 3 - 0
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -10300,6 +10300,9 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC,
     hlslSource->AddHLSLObjectMethodsIfNotReady(qt);
   } else if (qt->isArrayType()) {
     QualType eltQt(qt->getArrayElementTypeNoTypeQual(), 0);
+    while (eltQt->isArrayType())
+      eltQt = QualType(eltQt->getArrayElementTypeNoTypeQual(), 0);
+
     if (hlsl::IsObjectType(this, eltQt, &bDeprecatedEffectObject)) {
       // Add methods if not ready.
       HLSLExternalSource *hlslSource = HLSLExternalSource::FromSema(this);

+ 1 - 1
tools/clang/test/CodeGenHLSL/resource-in-cb.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
 
-// CHECK: error: object types not supported in global aggregate instances, cbuffers, or tbuffers.
+// CHECK: Tex1
 
 SamplerState Samp;
 cbuffer CB

+ 2 - 2
tools/clang/test/CodeGenHLSL/resource-in-cb2.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
 
-// CHECK: error: object types not supported in global aggregate instances, cbuffers, or tbuffers.
+// CHECK: var.res.Tex1
 
 struct Resource
 {
@@ -25,5 +25,5 @@ cbuffer CB
 
 float4 main(int4 a : A, float4 coord : TEXCOORD) : SV_TARGET
 {
-  return var.res.Tex1.Sample(Samp, coord.xy) * var.foo;
+  return var.res.Tex1.Sample(Samp, coord.xy) * var.res.foo;
 }

+ 19 - 0
tools/clang/test/CodeGenHLSL/resource-in-cb3.hlsl

@@ -0,0 +1,19 @@
+// RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
+
+// CHECK: Tex1                              texture     f32          2d      T0             t0     2
+
+SamplerState Samp;
+cbuffer CB
+{
+  Texture2D Tex1[2];
+  // Texture3D Tex2;
+  // RWTexture2D<float4> RWTex1;
+  // RWTexture3D<float4> RWTex2;
+  // SamplerState Samp;
+  float4 foo;
+};
+
+float4 main(int4 a : A, float4 coord : TEXCOORD) : SV_TARGET
+{
+  return Tex1[0].Sample(Samp, coord.xy) * foo;
+}

+ 23 - 0
tools/clang/test/CodeGenHLSL/resource-in-cb4.hlsl

@@ -0,0 +1,23 @@
+// RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
+
+// CHECK: Tex1                              texture     f32          2d      T0             t0     4
+
+
+SamplerState Samp;
+
+cbuffer CB
+{
+  Texture2D<float4> Tex1[2][2];
+  // Texture3D Tex2;
+  // RWTexture2D<float4> RWTex1;
+  // RWTexture3D<float4> RWTex2;
+  // SamplerState Samp;
+  float4 foo;
+};
+
+uint i;
+
+float4 main(int4 a : A, float2 coord : TEXCOORD) : SV_TARGET
+{
+  return Tex1[0][i].Sample(Samp, coord) * foo;
+}

+ 1 - 1
tools/clang/test/CodeGenHLSL/resource-in-cbv.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
 
-// CHECK: error: object types not supported in global aggregate instances, cbuffers, or tbuffers.
+// CHECK: error: object types not supported in cbuffer/tbuffer view arrays.
 
 SamplerState Samp;
 struct Resources

+ 1 - 1
tools/clang/test/CodeGenHLSL/resource-in-cbv2.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
 
-// CHECK: error: object types not supported in global aggregate instances, cbuffers, or tbuffers.
+// CHECK: error: object types not supported in cbuffer/tbuffer view arrays.
 
 SamplerState Samp;
 struct Resource

+ 1 - 1
tools/clang/test/CodeGenHLSL/resource-in-struct.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
 
-// CHECK: error: object types not supported in global aggregate instances, cbuffers, or tbuffers.
+// CHECK: res.Tex1
 
 SamplerState Samp;
 struct Resources

+ 1 - 1
tools/clang/test/CodeGenHLSL/resource-in-struct2.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
 
-// CHECK: error: object types not supported in global aggregate instances, cbuffers, or tbuffers.
+// CHECK: var.res.Tex1
 
 SamplerState Samp;
 struct Resource

+ 27 - 0
tools/clang/test/CodeGenHLSL/resource-in-struct3.hlsl

@@ -0,0 +1,27 @@
+// RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
+
+// CHECK: var.res.Tex1                      texture     f32          2d      T0             t0     2
+
+SamplerState Samp;
+struct Resource
+{
+  Texture2D Tex1[2];
+  // Texture3D Tex2;
+  // RWTexture2D<float4> RWTex1;
+  // RWTexture3D<float4> RWTex2;
+  // SamplerState Samp;
+  float4 foo;
+};
+
+struct MyStruct
+{
+  Resource res;
+  int4 bar;
+};
+
+MyStruct var;
+
+float4 main(int4 a : A, float4 coord : TEXCOORD) : SV_TARGET
+{
+  return var.res.Tex1[0].Sample(Samp, coord.xy) * var.res.foo;
+}

+ 1 - 1
tools/clang/test/CodeGenHLSL/resource-in-tb.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
 
-// CHECK: error: object types not supported in global aggregate instances, cbuffers, or tbuffers.
+// CHECK: Tex1
 
 SamplerState Samp;
 tbuffer TB

+ 2 - 2
tools/clang/test/CodeGenHLSL/resource-in-tb2.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
 
-// CHECK: error: object types not supported in global aggregate instances, cbuffers, or tbuffers.
+// CHECK: var.res.Tex1
 
 struct Resource
 {
@@ -26,5 +26,5 @@ tbuffer TB
 
 float4 main(int4 a : A, float4 coord : TEXCOORD) : SV_TARGET
 {
-  return var.res.Tex1.Sample(Samp, coord.xy) * var.foo;
+  return var.res.Tex1.Sample(Samp, coord.xy) * var.res.foo;
 }

+ 1 - 1
tools/clang/test/CodeGenHLSL/resource-in-tbv.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
 
-// CHECK: error: object types not supported in global aggregate instances, cbuffers, or tbuffers.
+// CHECK: error: object types not supported in cbuffer/tbuffer view arrays.
 
 SamplerState Samp;
 struct Resources

+ 1 - 1
tools/clang/test/CodeGenHLSL/resource-in-tbv2.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
 
-// CHECK: error: object types not supported in global aggregate instances, cbuffers, or tbuffers.
+// CHECK: error: object types not supported in cbuffer/tbuffer view arrays.
 
 SamplerState Samp;
 struct Resource

+ 20 - 5
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -614,15 +614,18 @@ public:
   TEST_METHOD(CodeGenBindings2)
   TEST_METHOD(CodeGenBindings3)
   TEST_METHOD(CodeGenResCopy)
-  TEST_METHOD(CodeGenResourceInStruct)
   TEST_METHOD(CodeGenResourceInCB)
-  TEST_METHOD(CodeGenResourceInCBV)
-  TEST_METHOD(CodeGenResourceInTB)
-  TEST_METHOD(CodeGenResourceInTBV)
-  TEST_METHOD(CodeGenResourceInStruct2)
   TEST_METHOD(CodeGenResourceInCB2)
+  TEST_METHOD(CodeGenResourceInCB3)
+  TEST_METHOD(CodeGenResourceInCB4)
+  TEST_METHOD(CodeGenResourceInCBV)
   TEST_METHOD(CodeGenResourceInCBV2)
+  TEST_METHOD(CodeGenResourceInStruct)
+  TEST_METHOD(CodeGenResourceInStruct2)
+  TEST_METHOD(CodeGenResourceInStruct3)
+  TEST_METHOD(CodeGenResourceInTB)
   TEST_METHOD(CodeGenResourceInTB2)
+  TEST_METHOD(CodeGenResourceInTBV)
   TEST_METHOD(CodeGenResourceInTBV2)
   TEST_METHOD(CodeGenResPhi)
   TEST_METHOD(CodeGenResPhi2)
@@ -3245,10 +3248,22 @@ TEST_F(CompilerTest, CodeGenResourceInStruct2) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\resource-in-struct2.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenResourceInStruct3) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\resource-in-struct3.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenResourceInCB2) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\resource-in-cb2.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenResourceInCB3) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\resource-in-cb3.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenResourceInCB4) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\resource-in-cb4.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenResourceInCBV2) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\resource-in-cbv2.hlsl");
 }