Explorar o código

Update Constant/TextureBuffer implementation. (#3147)

* Update Constant/TextureBuffer implementation.

* [spirv] Fixes needed for new Constant/TextureBuffer representation.


Co-authored-by: Ehsan Nasiri <[email protected]>
Xiang Li %!s(int64=5) %!d(string=hai) anos
pai
achega
8e3b9e7ed5
Modificáronse 28 ficheiros con 488 adicións e 259 borrados
  1. 0 3
      lib/DXIL/DxilUtil.cpp
  2. 9 1
      lib/HLSL/DxilContainerReflection.cpp
  3. 2 1
      lib/HLSL/DxilGenerationPass.cpp
  4. 1 1
      lib/HLSL/DxilValidation.cpp
  5. 2 0
      tools/clang/include/clang/AST/HlslTypes.h
  6. 30 10
      tools/clang/include/clang/Basic/Attr.td
  7. 0 2
      tools/clang/include/clang/Basic/TokenKinds.def
  8. 0 2
      tools/clang/include/clang/Parse/Parser.h
  9. 11 2
      tools/clang/include/clang/SPIRV/AstTypeProbe.h
  10. 1 1
      tools/clang/include/clang/Sema/SemaHLSL.h
  11. 21 0
      tools/clang/lib/AST/ASTContextHLSL.cpp
  12. 10 1
      tools/clang/lib/AST/HlslTypes.cpp
  13. 132 80
      tools/clang/lib/CodeGen/CGHLSLMS.cpp
  14. 26 12
      tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp
  15. 13 2
      tools/clang/lib/CodeGen/CGHLSLMSHelper.h
  16. 0 4
      tools/clang/lib/Parse/ParseDecl.cpp
  17. 0 54
      tools/clang/lib/Parse/ParseHLSL.cpp
  18. 0 2
      tools/clang/lib/Parse/Parser.cpp
  19. 23 8
      tools/clang/lib/SPIRV/AstTypeProbe.cpp
  20. 45 33
      tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
  21. 50 23
      tools/clang/lib/SPIRV/SpirvEmitter.cpp
  22. 1 1
      tools/clang/lib/Sema/SemaExprMember.cpp
  23. 61 5
      tools/clang/lib/Sema/SemaHLSL.cpp
  24. 19 0
      tools/clang/test/HLSLFileCheck/hlsl/objects/Cbuffer/retCBV.hlsl
  25. 20 0
      tools/clang/test/HLSLFileCheck/hlsl/objects/Cbuffer/retTBV.hlsl
  26. 4 4
      tools/clang/test/HLSLFileCheck/hlsl/objects/CbufferLegacy/cbuffer-struct.hlsl
  27. 3 3
      tools/clang/test/HLSLFileCheck/hlsl/objects/CbufferLegacy/cbuffer-structarray.hlsl
  28. 4 4
      tools/clang/test/HLSLFileCheck/hlsl/objects/CbufferLegacy/ctbuf.hlsl

+ 0 - 3
lib/DXIL/DxilUtil.cpp

@@ -615,9 +615,6 @@ bool IsHLSLResourceType(llvm::Type *Ty) {
     if (name.startswith("ConsumeStructuredBuffer<"))
       return true;
 
-    if (name.startswith("ConstantBuffer<"))
-      return true;
-
     if (name == "RaytracingAccelerationStructure")
       return true;
 

+ 9 - 1
lib/HLSL/DxilContainerReflection.cpp

@@ -1262,7 +1262,15 @@ void CShaderReflectionConstantBuffer::Initialize(
 
     // Replicate fxc bug, where Elements == 1 for inner struct of CB array, instead of 0.
     if (CB.GetRangeSize() > 1) {
-      DXASSERT(pVarType->m_Desc.Elements == 0, "otherwise, assumption is wrong");
+      DXASSERT(pVarType->m_Desc.Elements == 0,
+               "otherwise, assumption is wrong");
+      pVarType->m_Desc.Elements = 1;
+    } else if (CB.GetGlobalSymbol()
+                   ->getType()
+                   ->getPointerElementType()
+                   ->isArrayTy() &&
+               CB.GetRangeSize() == 1) {
+      // Set elements to 1 for size 1 array.
       pVarType->m_Desc.Elements = 1;
     }
 

+ 2 - 1
lib/HLSL/DxilGenerationPass.cpp

@@ -528,7 +528,8 @@ void DxilGenerationPass::GenerateDxilCBufferHandles() {
                              DIV->getScope());
     }
 
-    if (CB.GetRangeSize() == 1) {
+    if (CB.GetRangeSize() == 1 &&
+        !GV->getType()->getElementType()->isArrayTy()) {
       Function *createHandle =
           hlslOP->GetOpFunc(OP::OpCode::CreateHandleForLib,
                             GV->getType()->getElementType());

+ 1 - 1
lib/HLSL/DxilValidation.cpp

@@ -4083,7 +4083,7 @@ CollectCBufferRanges(DxilStructAnnotation *annotation,
 
 static void ValidateCBuffer(DxilCBuffer &cb, ValidationContext &ValCtx) {
   Type *Ty = cb.GetGlobalSymbol()->getType()->getPointerElementType();
-  if (cb.GetRangeSize() != 1) {
+  if (cb.GetRangeSize() != 1 || Ty->isArrayTy()) {
     Ty = Ty->getArrayElementType();
   }
   if (!isa<StructType>(Ty)) {

+ 2 - 0
tools/clang/include/clang/AST/HlslTypes.h

@@ -322,6 +322,7 @@ clang::CXXRecordDecl* DeclareTemplateTypeWithHandle(
 
 clang::CXXRecordDecl* DeclareUIntTemplatedTypeWithHandle(
   clang::ASTContext& context, llvm::StringRef typeName, llvm::StringRef templateParamName);
+clang::CXXRecordDecl *DeclareConstantBufferViewType(clang::ASTContext& context, bool bTBuf);
 clang::CXXRecordDecl* DeclareRayQueryType(clang::ASTContext& context);
 clang::CXXRecordDecl *DeclareResourceType(clang::ASTContext &context);
 
@@ -371,6 +372,7 @@ bool IsHLSLLineStreamType(clang::QualType type);
 bool IsHLSLTriangleStreamType(clang::QualType type);
 bool IsHLSLStreamOutputType(clang::QualType type);
 bool IsHLSLResourceType(clang::QualType type);
+bool IsHLSLBufferViewType(clang::QualType type);
 bool IsHLSLNumericOrAggregateOfNumericType(clang::QualType type);
 bool IsHLSLNumericUserDefinedType(clang::QualType type);
 bool IsHLSLAggregateType(clang::QualType type);

+ 30 - 10
tools/clang/include/clang/Basic/Attr.td

@@ -905,6 +905,32 @@ def HLSLWaveSensitive : InheritableAttr {
 
 // SPIRV Change Starts
 
+// StructuredBuffer types that can have associated counters
+def CounterStructuredBuffer : SubsetSubject<
+    Var,
+    [{S->hasGlobalStorage() && S->getType()->getAs<RecordType>() &&
+      (S->getType()->getAs<RecordType>()->getDecl()->getName() == "RWStructuredBuffer" ||
+       S->getType()->getAs<RecordType>()->getDecl()->getName() == "AppendStructuredBuffer" ||
+       S->getType()->getAs<RecordType>()->getDecl()->getName() == "ConsumeStructuredBuffer")}]>;
+
+// Global variable with "ConstantBuffer" type
+def ConstantBuffer
+    : SubsetSubject<
+          Var, [{S->hasGlobalStorage() && S->getType()->getAs<RecordType>() &&
+                 S->getType()->getAs<RecordType>()->getDecl() &&
+                 S->getType()->getAs<RecordType>()->getDecl()->getName() ==
+                     "ConstantBuffer"}]>;
+
+// Global variable with "ConstantBuffer" or "TextureBuffer" type
+def ConstantTextureBuffer
+    : SubsetSubject<
+          Var, [{S->hasGlobalStorage() && S->getType()->getAs<RecordType>() &&
+                 S->getType()->getAs<RecordType>()->getDecl() &&
+                 (S->getType()->getAs<RecordType>()->getDecl()->getName() ==
+                      "ConstantBuffer" ||
+                  S->getType()->getAs<RecordType>()->getDecl()->getName() ==
+                      "TextureBuffer")}]>;
+
 def VKBuiltIn : InheritableAttr {
   let Spellings = [CXX11<"vk", "builtin">];
   let Subjects = SubjectList<[Function, ParmVar, Field], ErrorDiag>;
@@ -931,20 +957,13 @@ def VKIndex : InheritableAttr {
 
 def VKBinding : InheritableAttr {
   let Spellings = [CXX11<"vk", "binding">];
-  let Subjects = SubjectList<[GlobalVar, HLSLBuffer], ErrorDiag, "ExpectedGlobalVarOrCTBuffer">;
+  let Subjects = SubjectList<[GlobalVar, HLSLBuffer, ConstantTextureBuffer],
+                             ErrorDiag, "ExpectedGlobalVarOrCTBuffer">;
   let Args = [IntArgument<"Binding">, DefaultIntArgument<"Set", 0>];
   let LangOpts = [SPIRV];
   let Documentation = [Undocumented];
 }
 
-// StructuredBuffer types that can have associated counters
-def CounterStructuredBuffer : SubsetSubject<
-    Var,
-    [{S->hasGlobalStorage() && S->getType()->getAs<RecordType>() &&
-      (S->getType()->getAs<RecordType>()->getDecl()->getName() == "RWStructuredBuffer" ||
-       S->getType()->getAs<RecordType>()->getDecl()->getName() == "AppendStructuredBuffer" ||
-       S->getType()->getAs<RecordType>()->getDecl()->getName() == "ConsumeStructuredBuffer")}]>;
-
 def VKCounterBinding : InheritableAttr {
   let Spellings = [CXX11<"vk", "counter_binding">];
   let Subjects = SubjectList<[CounterStructuredBuffer], ErrorDiag, "ExpectedCounterStructuredBuffer">;
@@ -1006,7 +1025,8 @@ def VKPostDepthCoverage : InheritableAttr {
 
 def VKShaderRecordNV : InheritableAttr {
   let Spellings = [CXX11<"vk", "shader_record_nv">];
-  let Subjects = SubjectList<[StructGlobalVar, HLSLBuffer], ErrorDiag, "ExpectedCTBuffer">;
+  let Subjects = SubjectList<[StructGlobalVar, HLSLBuffer, ConstantBuffer],
+                             ErrorDiag, "ExpectedCTBuffer">;
   let Args = [];
   let LangOpts = [SPIRV];
   let Documentation = [Undocumented];

+ 0 - 2
tools/clang/include/clang/Basic/TokenKinds.def

@@ -483,8 +483,6 @@ KEYWORD(__super                     , KEYMS)
 // HLSL Change: HLSL-specific keywords
 KEYWORD(cbuffer                     , KEYHLSL)
 KEYWORD(tbuffer                     , KEYHLSL)
-KEYWORD(ConstantBuffer              , KEYHLSL)
-KEYWORD(TextureBuffer               , KEYHLSL)
 KEYWORD(packoffset                  , KEYHLSL)
 KEYWORD(linear                      , KEYHLSL)
 KEYWORD(centroid                    , KEYHLSL)

+ 0 - 2
tools/clang/include/clang/Parse/Parser.h

@@ -2380,8 +2380,6 @@ private:
   // HLSL Change Starts
   Decl *ParseCTBuffer(unsigned Context, SourceLocation &DeclEnd,
     ParsedAttributesWithRange &attrs, SourceLocation InlineLoc = SourceLocation());
-  Decl *ParseConstBuffer(unsigned Context, SourceLocation &DeclEnd,
-    ParsedAttributesWithRange &attrs, SourceLocation InlineLoc = SourceLocation());
   // HLSL Change Ends
 
   Decl *ParseNamespace(unsigned Context, SourceLocation &DeclEnd,

+ 11 - 2
tools/clang/include/clang/SPIRV/AstTypeProbe.h

@@ -86,8 +86,17 @@ bool isMx1Matrix(QualType type, QualType *elemType = nullptr,
 bool isMxNMatrix(QualType type, QualType *elemType = nullptr,
                  uint32_t *rowCount = nullptr, uint32_t *colCount = nullptr);
 
-/// \brief Returns true if the decl is of ConstantBuffer/TextureBuffer type.
-bool isConstantTextureBuffer(const Decl *decl);
+/// \brief Returns true if the given type is a ConstantBuffer or an array of
+/// ConstantBuffers.
+bool isConstantBuffer(QualType);
+
+/// \brief Returns true if the given type is a TextureBuffer or an array of
+/// TextureBuffers.
+bool isTextureBuffer(QualType);
+
+/// \brief Returns true if the given type is a ConstantBuffer or TextureBuffer
+/// or an array of ConstantBuffers/TextureBuffers.
+bool isConstantTextureBuffer(QualType);
 
 /// \brief Returns true if the decl will have a SPIR-V resource type.
 ///

+ 1 - 1
tools/clang/include/clang/Sema/SemaHLSL.h

@@ -156,7 +156,7 @@ bool LookupRecordMemberExprForHLSL(
   clang::SourceLocation MemberLoc,
   clang::ExprResult &result);
 
-clang::ExprResult MaybeConvertScalarToVector(
+clang::ExprResult MaybeConvertMemberAccess(
   _In_ clang::Sema* Self,
   _In_ clang::Expr* E);
 

+ 21 - 0
tools/clang/lib/AST/ASTContextHLSL.cpp

@@ -831,6 +831,27 @@ CXXRecordDecl* hlsl::DeclareUIntTemplatedTypeWithHandle(
   return typeDeclBuilder.completeDefinition();
 }
 
+clang::CXXRecordDecl *
+hlsl::DeclareConstantBufferViewType(clang::ASTContext &context, bool bTBuf) {
+  // Create ConstantBufferView template declaration in translation unit scope
+  // like other resource.
+  // template<typename T> ConstantBuffer { int h; }
+  DeclContext *DC = context.getTranslationUnitDecl();
+
+  BuiltinTypeDeclBuilder typeDeclBuilder(DC, bTBuf ? "TextureBuffer"
+                                                   : "ConstantBuffer");
+  (void)typeDeclBuilder.addTypeTemplateParam("T");
+  typeDeclBuilder.startDefinition();
+  CXXRecordDecl *templateRecordDecl = typeDeclBuilder.getRecordDecl();
+
+  typeDeclBuilder.addField(
+      "h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
+
+  typeDeclBuilder.completeDefinition();
+
+  return templateRecordDecl;
+}
+
 CXXRecordDecl* hlsl::DeclareRayQueryType(ASTContext& context) {
   // template<uint kind> RayQuery { ... }
   BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), "RayQuery");

+ 10 - 1
tools/clang/lib/AST/HlslTypes.cpp

@@ -544,7 +544,7 @@ bool IsHLSLResourceType(clang::QualType type) {
     if (name == "SamplerState" || name == "SamplerComparisonState")
       return true;
 
-    if (name == "ConstantBuffer")
+    if (name == "ConstantBuffer" || name == "TextureBuffer")
       return true;
 
     if (name == "RaytracingAccelerationStructure")
@@ -553,6 +553,15 @@ bool IsHLSLResourceType(clang::QualType type) {
   return false;
 }
 
+bool IsHLSLBufferViewType(clang::QualType type) {
+  if (const RecordType *RT = type->getAs<RecordType>()) {
+    StringRef name = RT->getDecl()->getName();
+    if (name == "ConstantBuffer" || name == "TextureBuffer")
+      return true;
+  }
+  return false;
+}
+
 bool IsHLSLSubobjectType(clang::QualType type) {
   DXIL::SubobjectKind kind;
   DXIL::HitGroupType hgType;

+ 132 - 80
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -99,12 +99,15 @@ private:
   HLCBuffer &GetGlobalCBuffer() {
     return *static_cast<HLCBuffer*>(&(m_pHLModule->GetCBuffer(globalCBIndex)));
   }
+  void AddConstantToCB(GlobalVariable *CV, StringRef Name, QualType Ty,
+                       unsigned LowerBound, HLCBuffer &CB);
   void AddConstant(VarDecl *constDecl, HLCBuffer &CB);
   uint32_t AddSampler(VarDecl *samplerDecl);
   uint32_t AddUAVSRV(VarDecl *decl, hlsl::DxilResourceBase::Class resClass);
   bool SetUAVSRV(SourceLocation loc, hlsl::DxilResourceBase::Class resClass,
                  DxilResource *hlslRes, QualType QualTy);
   uint32_t AddCBuffer(HLSLBufferDecl *D);
+  uint32_t AddConstantBufferView(VarDecl *D);
   hlsl::DxilResourceBase::Class TypeToClass(clang::QualType Ty);
 
   void CreateSubobject(DXIL::SubobjectKind kind, const StringRef name, clang::Expr **args,
@@ -401,7 +404,7 @@ CGMSHLSLRuntime::CGMSHLSLRuntime(CodeGenModule &CGM)
            "else CGMSHLSLRuntime Constructor needs to be updated");
 
   // add globalCB
-  unique_ptr<HLCBuffer> CB = llvm::make_unique<HLCBuffer>();
+  unique_ptr<HLCBuffer> CB = llvm::make_unique<HLCBuffer>(false, false);
   std::string globalCBName = "$Globals";
   CB->SetGlobalSymbol(nullptr);
   CB->SetGlobalName(globalCBName);
@@ -2445,7 +2448,7 @@ void CGMSHLSLRuntime::addResource(Decl *D) {
       break;
     }
     case DXIL::ResourceClass::CBuffer:
-      DXASSERT(0, "cbuffer should not be here");
+      AddConstantBufferView(VD);
       break;
     }
   }
@@ -2503,7 +2506,7 @@ static DxilResourceBase::Class KeywordToClass(const std::string &keyword) {
     return DxilResourceBase::Class::CBuffer;
 
   if (keyword == "TextureBuffer")
-    return DxilResourceBase::Class::SRV;
+    return DxilResourceBase::Class::CBuffer;
 
   bool isSRV = keyword == "Buffer";
   isSRV |= keyword == "ByteAddressBuffer";
@@ -3079,6 +3082,25 @@ static bool IsResourceInType(const clang::ASTContext &context,
   return false; // no resources found
 }
 
+void CGMSHLSLRuntime::AddConstantToCB(GlobalVariable *CV, StringRef Name,
+                                      QualType Ty, unsigned LowerBound,
+                                      HLCBuffer &CB) {
+  std::unique_ptr<DxilResourceBase> pHlslConst =
+      llvm::make_unique<DxilResourceBase>(DXIL::ResourceClass::Invalid);
+  pHlslConst->SetLowerBound(LowerBound);
+  pHlslConst->SetSpaceID(0);
+  pHlslConst->SetGlobalSymbol(CV);
+  pHlslConst->SetGlobalName(Name);
+
+  DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
+
+  unsigned arrayEltSize = 0;
+  unsigned size = AddTypeAnnotation(Ty, dxilTypeSys, arrayEltSize);
+  pHlslConst->SetRangeSize(size);
+
+  CB.AddConst(pHlslConst);
+}
+
 void CGMSHLSLRuntime::AddConstant(VarDecl *constDecl, HLCBuffer &CB) {
   if (constDecl->getStorageClass() == SC_Static) {
     // For static inside cbuffer, take as global static.
@@ -3091,16 +3113,7 @@ void CGMSHLSLRuntime::AddConstant(VarDecl *constDecl, HLCBuffer &CB) {
     AddTypeAnnotation(constDecl->getType(), dxilTypeSys, arraySize);
     return;
   }
-  // Search defined structure for resource objects and fail
-  if (CB.GetRangeSize() > 1 &&
-      IsResourceInType(CGM.getContext(), constDecl->getType())) {
-    DiagnosticsEngine &Diags = CGM.getDiags();
-    unsigned DiagID = Diags.getCustomDiagID(
-        DiagnosticsEngine::Error,
-        "object types not supported in cbuffer/tbuffer view arrays.");
-    Diags.Report(constDecl->getLocation(), DiagID);
-    return;
-  }
+
   llvm::Constant *constVal = CGM.GetAddrOfGlobalVar(constDecl);
   // Add debug info for constVal.
   if (CGDebugInfo *DI = CGM.getModuleDebugInfo())
@@ -3176,37 +3189,16 @@ void CGMSHLSLRuntime::AddConstant(VarDecl *constDecl, HLCBuffer &CB) {
       break;
     }
   }
-
-  std::unique_ptr<DxilResourceBase> pHlslConst = llvm::make_unique<DxilResourceBase>(DXIL::ResourceClass::Invalid);
-  pHlslConst->SetLowerBound(UINT_MAX);
-  pHlslConst->SetSpaceID(0);
-  pHlslConst->SetGlobalSymbol(cast<llvm::GlobalVariable>(constVal));
-  pHlslConst->SetGlobalName(constDecl->getName());
-
-  if (userOffset) {
-    pHlslConst->SetLowerBound(offset);
-  }
   
-  DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
-  // Just add type annotation here.
-  // Offset will be allocated later.
-  QualType Ty = constDecl->getType();
-  if (CB.GetRangeSize() != 1) {
-    while (Ty->isArrayType()) {
-      Ty = Ty->getAsArrayTypeUnsafe()->getElementType();
-    }
-  }
-  unsigned arrayEltSize = 0;
-  unsigned size = AddTypeAnnotation(Ty, dxilTypeSys, arrayEltSize);
-  pHlslConst->SetRangeSize(size);
-
-  CB.AddConst(pHlslConst);
+  unsigned LowerBound = userOffset ? offset : UINT_MAX;
+  AddConstantToCB(cast<llvm::GlobalVariable>(constVal), constDecl->getName(),
+                  constDecl->getType(), LowerBound, CB);
 
   // Save fieldAnnotation for the const var.
   DxilFieldAnnotation fieldAnnotation;
   if (userOffset)
     fieldAnnotation.SetCBufferOffset(offset);
-
+  QualType Ty = constDecl->getType();
   // Get the nested element type.
   if (Ty->isArrayType()) {
     while (const ConstantArrayType *arrayTy =
@@ -3219,64 +3211,110 @@ void CGMSHLSLRuntime::AddConstant(VarDecl *constDecl, HLCBuffer &CB) {
   m_ConstVarAnnotationMap[constVal] = fieldAnnotation;
 }
 
-uint32_t CGMSHLSLRuntime::AddCBuffer(HLSLBufferDecl *D) {
-  unique_ptr<HLCBuffer> CB = llvm::make_unique<HLCBuffer>();
+namespace {
+unique_ptr<HLCBuffer> CreateHLCBuf(NamedDecl *D, bool bIsView, bool bIsTBuf) {
+  unique_ptr<HLCBuffer> CB = llvm::make_unique<HLCBuffer>(bIsView, bIsTBuf);
 
   // setup the CB
   CB->SetGlobalSymbol(nullptr);
   CB->SetGlobalName(D->getNameAsString());
   CB->SetSpaceID(UINT_MAX);
   CB->SetLowerBound(UINT_MAX);
-  if (!D->isCBuffer()) {
+  if (bIsTBuf)
     CB->SetKind(DXIL::ResourceKind::TBuffer);
-  }
+  InitFromUnusualAnnotations(*CB, *D);
 
-  // the global variable will only used once by the createHandle?
-  // SetHandle(llvm::Value *pHandle);
+  return CB;
+}
 
-  InitFromUnusualAnnotations(*CB, *D);
+bool IsTextureBufferViewName(StringRef keyword) {
+  return keyword == "TextureBuffer";
+}
+
+bool IsTextureBufferView(clang::QualType Ty, clang::ASTContext &context) {
+  Ty = Ty.getCanonicalType();
+  if (const clang::ArrayType *arrayType = context.getAsArrayType(Ty)) {
+    return IsTextureBufferView(arrayType->getElementType(), context);
+  } else if (const RecordType *RT = Ty->getAsStructureType()) {
+    return IsTextureBufferViewName(RT->getDecl()->getName());
+  } else if (const RecordType *RT = Ty->getAs<RecordType>()) {
+    if (const ClassTemplateSpecializationDecl *templateDecl =
+            dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl())) {
+      return IsTextureBufferViewName(templateDecl->getName());
+    }
+  }
+  return false;
+}
+} // namespace
+
+uint32_t CGMSHLSLRuntime::AddCBuffer(HLSLBufferDecl *D) {
+  unique_ptr<HLCBuffer> CB = CreateHLCBuf(D, false, !D->isCBuffer());
 
   // Add constant
-  if (D->isConstantBufferView()) {
-    VarDecl *constDecl = cast<VarDecl>(*D->decls_begin());
-    CB->SetRangeSize(1);
-    QualType Ty = constDecl->getType();
-    if (Ty->isArrayType()) {
-      if (!Ty->isIncompleteArrayType()) {
-        unsigned arraySize = 1;
-        while (Ty->isArrayType()) {
-          Ty = Ty->getCanonicalTypeUnqualified();
-          const ConstantArrayType *AT = cast<ConstantArrayType>(Ty);
-          arraySize *= AT->getSize().getLimitedValue();
-          Ty = AT->getElementType();
-        }
-        CB->SetRangeSize(arraySize);
-      } else {
-        CB->SetRangeSize(UINT_MAX);
-      }
+  auto declsEnds = D->decls_end();
+  CB->SetRangeSize(1);
+  for (auto it = D->decls_begin(); it != declsEnds; it++) {
+    if (VarDecl *constDecl = dyn_cast<VarDecl>(*it)) {
+      AddConstant(constDecl, *CB.get());
+    } else if (isa<EmptyDecl>(*it)) {
+      // Nothing to do for this declaration.
+    } else if (isa<CXXRecordDecl>(*it)) {
+      // Nothing to do for this declaration.
+    } else if (isa<FunctionDecl>(*it)) {
+      // A function within an cbuffer is effectively a top-level function,
+      // as it only refers to globally scoped declarations.
+      this->CGM.EmitTopLevelDecl(*it);
+    } else {
+      HLSLBufferDecl *inner = cast<HLSLBufferDecl>(*it);
+      GetOrCreateCBuffer(inner);
     }
-    AddConstant(constDecl, *CB.get());
-  } else {
-    auto declsEnds = D->decls_end();
-    CB->SetRangeSize(1);
-    for (auto it = D->decls_begin(); it != declsEnds; it++) {
-      if (VarDecl *constDecl = dyn_cast<VarDecl>(*it)) {
-        AddConstant(constDecl, *CB.get());
-      } else if (isa<EmptyDecl>(*it)) {
-        // Nothing to do for this declaration.
-      } else if (isa<CXXRecordDecl>(*it)) {
-        // Nothing to do for this declaration.
-      } else if (isa<FunctionDecl>(*it)) {
-        // A function within an cbuffer is effectively a top-level function,
-        // as it only refers to globally scoped declarations.
-        this->CGM.EmitTopLevelDecl(*it);
-      } else {
-        HLSLBufferDecl *inner = cast<HLSLBufferDecl>(*it);
-        GetOrCreateCBuffer(inner);
+  }
+
+  CB->SetID(m_pHLModule->GetCBuffers().size());
+  return m_pHLModule->AddCBuffer(std::move(CB));
+}
+
+uint32_t CGMSHLSLRuntime::AddConstantBufferView(VarDecl *D) {
+  QualType Ty = D->getType();
+  unique_ptr<HLCBuffer> CB = CreateHLCBuf(D, true, IsTextureBufferView(Ty, CGM.getContext()));
+
+  CB->SetRangeSize(1);
+
+  if (Ty->isArrayType()) {
+    if (!Ty->isIncompleteArrayType()) {
+      unsigned arraySize = 1;
+      while (Ty->isArrayType()) {
+        Ty = Ty->getCanonicalTypeUnqualified();
+        const ConstantArrayType *AT = cast<ConstantArrayType>(Ty);
+        arraySize *= AT->getSize().getLimitedValue();
+        Ty = AT->getElementType();
       }
+      CB->SetRangeSize(arraySize);
+    } else {
+      Ty = QualType(Ty->getArrayElementTypeNoTypeQual(),0);
+      CB->SetRangeSize(UINT_MAX);
     }
+    CB->SetIsArray();
   }
 
+  QualType ResultTy = hlsl::GetHLSLResourceResultType(Ty);
+
+    // Search defined structure for resource objects and fail
+  if (CB->GetRangeSize() > 1 && IsResourceInType(CGM.getContext(), ResultTy)) {
+    DiagnosticsEngine &Diags = CGM.getDiags();
+    unsigned DiagID = Diags.getCustomDiagID(
+        DiagnosticsEngine::Error,
+        "object types not supported in cbuffer/tbuffer view arrays.");
+    Diags.Report(D->getLocation(), DiagID);
+    return UINT_MAX;
+  }
+  // Not allow offset for CBV.
+  unsigned LowerBound = 0;
+
+  GlobalVariable *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));
+  AddConstantToCB(GV, D->getName(), ResultTy, LowerBound, *CB.get());
+
+  CB->SetResultType(CGM.getTypes().ConvertType(ResultTy));
   CB->SetID(m_pHLModule->GetCBuffers().size());
   return m_pHLModule->AddCBuffer(std::move(CB));
 }
@@ -5229,6 +5267,20 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF,
     // Store to resource ptr.
     CGF.Builder.CreateStore(V, DestPtr);
     return;
+  } else if (GetResourceClassForType(CGM.getContext(), SrcTy) ==
+             DXIL::ResourceClass::CBuffer) {
+    llvm::Type *ResultTy =
+        CGM.getTypes().ConvertType(hlsl::GetHLSLResourceResultType(SrcTy));
+    if (ResultTy == DestPtrTy) {
+      // Cast ConstantBuffer to result type then copy.
+      Value *Cast = CGF.Builder.CreateBitCast(
+          SrcPtr,
+          ResultTy->getPointerTo(DestPtr->getType()->getPointerAddressSpace()));
+      unsigned size = TheModule.getDataLayout().getTypeAllocSize(
+          DestPtrTy);
+      CGF.Builder.CreateMemCpy(DestPtr, Cast, size, 1);
+      return;
+    }
   } else if (dxilutil::IsHLSLObjectType(dxilutil::GetArrayEltTy(SrcPtrTy)) &&
              dxilutil::IsHLSLObjectType(dxilutil::GetArrayEltTy(DestPtrTy))) {
     unsigned sizeSrc = TheModule.getDataLayout().getTypeAllocSize(SrcPtrTy);

+ 26 - 12
tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp

@@ -1742,10 +1742,14 @@ void ReplaceUseInFunction(Value *V, Value *NewV, Function *F,
     if (Instruction *I = dyn_cast<Instruction>(user)) {
       if (I->getParent()->getParent() == F) {
         // replace use with GEP if in F
-        for (unsigned i = 0; i < I->getNumOperands(); i++) {
-          if (I->getOperand(i) == V)
-            I->setOperand(i, NewV);
+        if (BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
+          if (BCI->getType() == NewV->getType()) {
+            I->replaceAllUsesWith(NewV);
+            I->eraseFromParent();
+            continue;
+          }
         }
+        I->replaceUsesOfWith(V, NewV);
       }
     } else {
       // For constant operator, create local clone which use GEP.
@@ -1798,7 +1802,7 @@ bool CreateCBufferVariable(HLCBuffer &CB, HLModule &HLM, llvm::Type *HandleTy) {
   SmallVector<llvm::Type *, 4> Elements;
   for (const std::unique_ptr<DxilResourceBase> &C : CB.GetConstants()) {
     Value *GV = C->GetGlobalSymbol();
-    if (GV->hasNUsesOrMore(1))
+    if (!GV->use_empty())
       bUsed = true;
     // Global variable must be pointer type.
     llvm::Type *Ty = GV->getType()->getPointerElementType();
@@ -1810,18 +1814,28 @@ bool CreateCBufferVariable(HLCBuffer &CB, HLModule &HLM, llvm::Type *HandleTy) {
 
   llvm::Module &M = *HLM.GetModule();
 
-  bool isCBArray = CB.GetRangeSize() != 1;
+  bool isCBArray = CB.IsArray();
   llvm::GlobalVariable *cbGV = nullptr;
   llvm::Type *cbTy = nullptr;
 
   unsigned cbIndexDepth = 0;
   if (!isCBArray) {
-    llvm::StructType *CBStructTy =
-        llvm::StructType::create(Elements, CB.GetGlobalName());
-    cbGV = new llvm::GlobalVariable(M, CBStructTy, /*IsConstant*/ true,
-                                    llvm::GlobalValue::ExternalLinkage,
-                                    /*InitVal*/ nullptr, CB.GetGlobalName());
-    cbTy = cbGV->getType();
+    if (CB.IsView()) {
+      llvm::StructType *CBStructTy =
+          llvm::StructType::create(CB.GetResultType(), CB.GetGlobalName());
+      cbGV = new llvm::GlobalVariable(M, CBStructTy,
+                                      /*IsConstant*/ true,
+                                      llvm::GlobalValue::ExternalLinkage,
+                                      /*InitVal*/ nullptr, CB.GetGlobalName());
+      cbTy = cbGV->getType();
+    } else {
+      llvm::StructType *CBStructTy =
+          llvm::StructType::create(Elements, CB.GetGlobalName());
+      cbGV = new llvm::GlobalVariable(M, CBStructTy, /*IsConstant*/ true,
+                                      llvm::GlobalValue::ExternalLinkage,
+                                      /*InitVal*/ nullptr, CB.GetGlobalName());
+      cbTy = cbGV->getType();
+    }
   } else {
     // For array of ConstantBuffer, create array of struct instead of struct of
     // array.
@@ -1838,7 +1852,7 @@ bool CreateCBufferVariable(HLCBuffer &CB, HLModule &HLM, llvm::Type *HandleTy) {
 
     // Add one level struct type to match normal case.
     llvm::StructType *CBStructTy =
-        llvm::StructType::create({CBEltTy}, CB.GetGlobalName());
+        llvm::StructType::create({CB.GetResultType()}, CB.GetGlobalName());
 
     llvm::ArrayType *CBArrayTy =
         llvm::ArrayType::get(CBStructTy, CB.GetRangeSize());

+ 13 - 2
tools/clang/lib/CodeGen/CGHLSLMSHelper.h

@@ -58,7 +58,8 @@ struct PatchConstantInfo {
 /// Use this class to represent HLSL cbuffer in high-level DXIL.
 class HLCBuffer : public hlsl::DxilCBuffer {
 public:
-  HLCBuffer() = default;
+  HLCBuffer(bool bIsView, bool bIsTBuf)
+      : bIsView(bIsView), bIsTBuf(bIsTBuf), bIsArray(false), ResultTy(nullptr) {}
   virtual ~HLCBuffer() = default;
 
   void AddConst(std::unique_ptr<DxilResourceBase> &pItem) {
@@ -70,11 +71,21 @@ public:
     return constants;
   }
 
+  bool IsView() { return bIsView; }
+  bool IsTBuf() { return bIsTBuf; }
+  bool IsArray() { return bIsArray; }
+  void SetIsArray() { bIsArray = true; }
+  llvm::Type *GetResultType() { return ResultTy; }
+  void SetResultType(llvm::Type *Ty) { ResultTy = Ty; }
+
 private:
   std::vector<std::unique_ptr<DxilResourceBase>>
       constants; // constants inside const buffer
+  bool bIsView;
+  bool bIsTBuf;
+  bool bIsArray;
+  llvm::Type *ResultTy;
 };
-
 // Scope to help transform multiple returns.
 struct Scope {
  enum class ScopeKind {

+ 0 - 4
tools/clang/lib/Parse/ParseDecl.cpp

@@ -1950,10 +1950,6 @@ Parser::DeclGroupPtrTy Parser::ParseDeclaration(unsigned Context,
   case tok::kw_tbuffer:
     SingleDecl = ParseCTBuffer(Context, DeclEnd, attrs);
     break;
-  case tok::kw_ConstantBuffer:
-  case tok::kw_TextureBuffer:
-    SingleDecl = ParseConstBuffer(Context, DeclEnd, attrs);
-    break;
   // HLSL Change Ends
   case tok::kw_namespace:
     ProhibitAttributes(attrs);

+ 0 - 54
tools/clang/lib/Parse/ParseHLSL.cpp

@@ -74,60 +74,6 @@ Decl *Parser::ParseCTBuffer(unsigned Context, SourceLocation &DeclEnd,
   return decl;
 }
 
-Decl *Parser::ParseConstBuffer(unsigned Context, SourceLocation &DeclEnd,
-                               ParsedAttributesWithRange &attrs,
-                               SourceLocation InlineLoc) {
-  bool isCBuffer = Tok.is(tok::kw_ConstantBuffer);
-  assert((isCBuffer || Tok.is(tok::kw_TextureBuffer)) && "Not a ConstantBuffer or TextureBuffer!");
-
-  SourceLocation BufferLoc = ConsumeToken(); // eat the 'ConstantBuffer'.
-
-  if (!Tok.is(tok::less)) {
-    Diag(Tok, diag::err_expected) << tok::less;
-    return nullptr;
-  }
-  ConsumeToken(); // eat the <
-
-  ParsingDeclSpec PDS(*this);
-  ParseDeclarationSpecifiers(PDS);
-
-  if (!Tok.is(tok::greater)) {
-    Diag(Tok, diag::err_expected) << tok::greater;
-    return nullptr;
-  }
-  ConsumeToken(); // eat the >
-
-  PDS.takeAttributesFrom(attrs);
-
-  Actions.ActOnStartHLSLBufferView();
-  Parser::DeclGroupPtrTy dcl = ParseDeclGroup(PDS, Declarator::FileContext);
-
-  // If parsing of decl group fails, then decl group must have been illformed. Bail out!
-  // Note that we don't have to generate any diagnostics here as it was already
-  // generated previously in ParseDirectDeclarator().
-  if (!dcl)
-    return nullptr;
-
-  // Check if the register type is valid
-  NamedDecl *namedDecl = cast<NamedDecl>(dcl.get().getSingleDecl());
-  ArrayRef<hlsl::UnusualAnnotation*> annotations = namedDecl->getUnusualAnnotations();
-  for (hlsl::UnusualAnnotation* annotation : annotations) {
-    if (const auto *regAssignment = dyn_cast<hlsl::RegisterAssignment>(annotation)) {
-      if (regAssignment->isSpaceOnly())
-        continue;
-      if (isCBuffer && regAssignment->RegisterType != 'b' && regAssignment->RegisterType != 'B') {
-        Diag(namedDecl->getLocation(), diag::err_hlsl_incorrect_bind_semantic) << "'b'";
-      }
-      else if (!isCBuffer && regAssignment->RegisterType != 't' && regAssignment->RegisterType != 'T') {
-        Diag(namedDecl->getLocation(), diag::err_hlsl_incorrect_bind_semantic) << "'t'";
-      }
-    }
-  }
-
-  Decl *decl = Actions.ActOnHLSLBufferView(getCurScope(), BufferLoc, dcl, isCBuffer);
-
-  return decl;
-}
 
 /// ParseHLSLAttributeSpecifier - Parse an HLSL attribute-specifier. 
 ///

+ 0 - 2
tools/clang/lib/Parse/Parser.cpp

@@ -766,8 +766,6 @@ Parser::ParseExternalDeclaration(ParsedAttributesWithRange &attrs,
   // HLSL Change Starts: Start parsing declaration of cbuffer and tbuffers
   case tok::kw_cbuffer:
   case tok::kw_tbuffer:
-  case tok::kw_ConstantBuffer:
-  case tok::kw_TextureBuffer:
     // HLSL Change Ends
   case tok::kw_using:
   case tok::kw_namespace:

+ 23 - 8
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -265,18 +265,33 @@ bool isSubpassInputMS(QualType type) {
   return false;
 }
 
-bool isConstantTextureBuffer(const Decl *decl) {
-  if (const auto *bufferDecl = dyn_cast<HLSLBufferDecl>(decl->getDeclContext()))
-    // Make sure we are not returning true for VarDecls inside cbuffer/tbuffer.
-    return bufferDecl->isConstantBufferView();
+bool isConstantBuffer(clang::QualType type) {
+  // Strip outer arrayness first
+  while (type->isArrayType())
+    type = type->getAsArrayTypeUnsafe()->getElementType();
+  if (const RecordType *RT = type->getAs<RecordType>()) {
+    StringRef name = RT->getDecl()->getName();
+    return name == "ConstantBuffer";
+  }
+  return false;
+}
 
+bool isTextureBuffer(clang::QualType type) {
+  // Strip outer arrayness first
+  while (type->isArrayType())
+    type = type->getAsArrayTypeUnsafe()->getElementType();
+  if (const RecordType *RT = type->getAs<RecordType>()) {
+    StringRef name = RT->getDecl()->getName();
+    return name == "TextureBuffer";
+  }
   return false;
 }
 
-bool isResourceType(const ValueDecl *decl) {
-  if (isConstantTextureBuffer(decl))
-    return true;
+bool isConstantTextureBuffer(QualType type) {
+  return isConstantBuffer(type) || isTextureBuffer(type);
+}
 
+bool isResourceType(const ValueDecl *decl) {
   QualType declType = decl->getType();
 
   // Deprive the arrayness to see the element type
@@ -1306,7 +1321,7 @@ bool isStructureContainingAnyKindOfBuffer(QualType type) {
       while (fieldType->isArrayType())
         fieldType = fieldType->getAsArrayTypeUnsafe()->getElementType();
       if (isAKindOfStructuredOrByteBuffer(fieldType) ||
-          isConstantTextureBuffer(field) ||
+          isConstantTextureBuffer(fieldType) ||
           isStructureContainingAnyKindOfBuffer(fieldType)) {
         return true;
       }

+ 45 - 33
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -212,17 +212,6 @@ bool shouldSkipInStructLayout(const Decl *decl) {
 
   const auto *declContext = decl->getDeclContext();
 
-  // Special check for ConstantBuffer/TextureBuffer, whose DeclContext is a
-  // HLSLBufferDecl. So that we need to check the HLSLBufferDecl's parent decl
-  // to check whether this is a ConstantBuffer/TextureBuffer defined in the
-  // global namespace.
-  // Note that we should not be seeing ConstantBuffer/TextureBuffer for normal
-  // cbuffer/tbuffer or push constant blocks. So this case should only happen
-  // for $Globals cbuffer.
-  if (isConstantTextureBuffer(decl) &&
-      declContext->getLexicalParent()->isTranslationUnit())
-    return true;
-
   // $Globals' "struct" is the TranslationUnit, so we should ignore resources
   // in the TranslationUnit "struct" and its child namespaces.
   if (declContext->isTranslationUnit() || declContext->isNamespace()) {
@@ -999,6 +988,7 @@ void DeclResultIdMapper::createEnumConstant(const EnumConstantDecl *decl) {
 }
 
 SpirvVariable *DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
+  // This function handles creation of cbuffer or tbuffer.
   const auto usageKind =
       decl->isCBuffer() ? ContextUsageKind::CBuffer : ContextUsageKind::TBuffer;
   const std::string structName = "type." + decl->getName().str();
@@ -1032,44 +1022,67 @@ SpirvVariable *DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
 }
 
 SpirvVariable *DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
+  // This function handles creation of ConstantBuffer<T> or TextureBuffer<T>.
+  // The way this is represented in the AST is as follows:
+  //
+  // |-VarDecl MyCbuffer 'ConstantBuffer<T>':'ConstantBuffer<T>'
+  // |-CXXRecordDecl referenced struct T definition
+  //   |-CXXRecordDecl implicit struct T
+  //   |-FieldDecl
+  //   |-...
+  //   |-FieldDecl
+
+  const QualType type = decl->getType();
+  assert(isConstantTextureBuffer(type));
   const RecordType *recordType = nullptr;
+  const RecordType *templatedType = nullptr;
   int arraySize = 0;
 
   // In case we have an array of ConstantBuffer/TextureBuffer:
-  if (const auto *arrayType = decl->getType()->getAsArrayTypeUnsafe()) {
-    recordType = arrayType->getElementType()->getAs<RecordType>();
-    if (const auto *caType =
-            astContext.getAsConstantArrayType(decl->getType())) {
+  if (const auto *arrayType = type->getAsArrayTypeUnsafe()) {
+    const QualType elemType = arrayType->getElementType();
+    recordType = elemType->getAs<RecordType>();
+    templatedType =
+        hlsl::GetHLSLResourceResultType(elemType)->getAs<RecordType>();
+    if (const auto *caType = astContext.getAsConstantArrayType(type)) {
       arraySize = static_cast<uint32_t>(caType->getSize().getZExtValue());
     } else {
       arraySize = -1;
     }
   } else {
-    recordType = decl->getType()->getAs<RecordType>();
+    recordType = type->getAs<RecordType>();
+    templatedType = hlsl::GetHLSLResourceResultType(type)->getAs<RecordType>();
   }
   if (!recordType) {
     emitError("constant/texture buffer type %0 unimplemented",
               decl->getLocStart())
-        << decl->getType();
-    return 0;
+        << type;
+    return nullptr;
+  }
+  if (!templatedType) {
+    emitError(
+        "the underlying type for constant/texture buffer must be a struct",
+        decl->getLocStart())
+        << type;
+    return nullptr;
   }
 
-  const auto *context = cast<HLSLBufferDecl>(decl->getDeclContext());
-  const auto usageKind = context->isCBuffer() ? ContextUsageKind::CBuffer
-                                              : ContextUsageKind::TBuffer;
+  const bool isConstBuffer = isConstantBuffer(type);
+  const auto usageKind =
+      isConstBuffer ? ContextUsageKind::CBuffer : ContextUsageKind::TBuffer;
 
-  const char *ctBufferName =
-      context->isCBuffer() ? "ConstantBuffer." : "TextureBuffer.";
-  const std::string structName = "type." + std::string(ctBufferName) +
-                                 recordType->getDecl()->getName().str();
+  const std::string structName = "type." +
+                                 recordType->getDecl()->getName().str() + "." +
+                                 templatedType->getDecl()->getName().str();
 
   SpirvVariable *bufferVar = createStructOrStructArrayVarOfExplicitLayout(
-      recordType->getDecl(), arraySize, usageKind, structName, decl->getName());
+      templatedType->getDecl(), arraySize, usageKind, structName,
+      decl->getName());
 
   // We register the VarDecl here.
   astDecls[decl] = DeclSpirvInfo(bufferVar);
   resourceVars.emplace_back(
-      bufferVar, decl, decl->getLocation(), getResourceBinding(context),
+      bufferVar, decl, decl->getLocation(), getResourceBinding(decl),
       decl->getAttr<VKBindingAttr>(), decl->getAttr<VKCounterBindingAttr>());
 
   return bufferVar;
@@ -1097,7 +1110,8 @@ SpirvVariable *DeclResultIdMapper::createPushConstant(const VarDecl *decl) {
 
 SpirvVariable *
 DeclResultIdMapper::createShaderRecordBufferNV(const VarDecl *decl) {
-  const auto *recordType = decl->getType()->getAs<RecordType>();
+  const auto *recordType =
+      hlsl::GetHLSLResourceResultType(decl->getType())->getAs<RecordType>();
   assert(recordType);
 
   const std::string structName =
@@ -3513,11 +3527,9 @@ QualType DeclResultIdMapper::getTypeAndCreateCounterForPotentialAliasVar(
   // Whether we should generate this decl as an alias variable.
   bool genAlias = false;
 
-  if (const auto *buffer = dyn_cast<HLSLBufferDecl>(decl->getDeclContext())) {
-    // For ConstantBuffer and TextureBuffer
-    if (buffer->isConstantBufferView())
-      genAlias = true;
-  } else if (isOrContainsAKindOfStructuredOrByteBuffer(type)) {
+  // For ConstantBuffers, TextureBuffers, StructuredBuffers, ByteAddressBuffers
+  if (isConstantTextureBuffer(type) ||
+      isOrContainsAKindOfStructuredOrByteBuffer(type)) {
     genAlias = true;
   }
 

+ 50 - 23
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -113,8 +113,10 @@ inline bool isExternalVar(const VarDecl *var) {
 const DeclContext *isConstantTextureBufferDeclRef(const Expr *expr) {
   if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr->IgnoreParenCasts()))
     if (const auto *varDecl = dyn_cast<VarDecl>(declRefExpr->getFoundDecl()))
-      if (isConstantTextureBuffer(varDecl))
-        return varDecl->getType()->getAs<RecordType>()->getDecl();
+      if (isConstantTextureBuffer(varDecl->getType()))
+        return hlsl::GetHLSLResourceResultType(varDecl->getType())
+            ->getAs<RecordType>()
+            ->getDecl();
 
   return nullptr;
 }
@@ -283,18 +285,6 @@ inline bool evaluatesToConstZero(const Expr *expr, ASTContext &astContext) {
   return false;
 }
 
-/// Returns the HLSLBufferDecl if the given VarDecl is inside a cbuffer/tbuffer.
-/// Returns nullptr otherwise, including varDecl is a ConstantBuffer or
-/// TextureBuffer itself.
-inline const HLSLBufferDecl *getCTBufferContext(const VarDecl *varDecl) {
-  if (const auto *bufferDecl =
-          dyn_cast<HLSLBufferDecl>(varDecl->getDeclContext()))
-    // Filter ConstantBuffer/TextureBuffer
-    if (!bufferDecl->isConstantBufferView())
-      return bufferDecl;
-  return nullptr;
-}
-
 /// Returns the real definition of the callee of the given CallExpr.
 ///
 /// If we are calling a forward-declared function, callee will be the
@@ -657,14 +647,7 @@ void SpirvEmitter::doDecl(const Decl *decl) {
   }
 
   if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
-    // We can have VarDecls inside cbuffer/tbuffer. For those VarDecls, we need
-    // to emit their cbuffer/tbuffer as a whole and access each individual one
-    // using access chains.
-    if (const auto *bufferDecl = getCTBufferContext(varDecl)) {
-      doHLSLBufferDecl(bufferDecl);
-    } else {
       doVarDecl(varDecl);
-    }
   } else if (const auto *namespaceDecl = dyn_cast<NamespaceDecl>(decl)) {
     for (auto *subDecl : namespaceDecl->decls())
       // Note: We only emit functions as they are discovered through the call
@@ -1222,6 +1205,8 @@ bool SpirvEmitter::validateVKAttributes(const NamedDecl *decl) {
       isValidType = bufDecl->isCBuffer();
     else if ((bufDecl = dyn_cast<HLSLBufferDecl>(decl->getDeclContext())))
       isValidType = bufDecl->isCBuffer();
+    else if(isa<VarDecl>(decl))
+      isValidType = isConstantBuffer(dyn_cast<VarDecl>(decl)->getType());
 
     if (!isValidType) {
       emitError(
@@ -1352,8 +1337,20 @@ void SpirvEmitter::doVarDecl(const VarDecl *decl) {
     return;
   }
 
-  if (isa<HLSLBufferDecl>(decl->getDeclContext())) {
-    // This is a VarDecl of a ConstantBuffer/TextureBuffer type.
+  // We can have VarDecls inside cbuffer/tbuffer. For those VarDecls, we need
+  // to emit their cbuffer/tbuffer as a whole and access each individual one
+  // using access chains.
+  // cbuffers and tbuffers are HLSLBufferDecls
+  // ConstantBuffers and TextureBuffers are not HLSLBufferDecls.
+  if (const auto *bufferDecl =
+          dyn_cast<HLSLBufferDecl>(decl->getDeclContext())) {
+    // This is a VarDecl of cbuffer/tbuffer type.
+    doHLSLBufferDecl(bufferDecl);
+    return;
+  }
+
+  if (isConstantTextureBuffer(decl->getType())) {
+    // This is a VarDecl of ConstantBuffer/TextureBuffer type.
     (void)declIdMapper.createCTBuffer(decl);
     return;
   }
@@ -2671,6 +2668,16 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr) {
 SpirvInstruction *SpirvEmitter::processFlatConversion(
     const QualType type, const QualType initType, SpirvInstruction *initInstr,
     SourceLocation srcLoc) {
+  // When translating ConstantBuffer<T> or TextureBuffer<T> types, we consider
+  // the underlying type (T), and therefore we should bypass the FlatConversion
+  // node when accessing these types:
+  // `-MemberExpr
+  //   `-ImplicitCastExpr 'const T' lvalue <FlatConversion>
+  //     `-ArraySubscriptExpr 'ConstantBuffer<T>':'ConstantBuffer<T>' lvalue
+  if(isConstantTextureBuffer(initType)) {
+    return initInstr;
+  }
+
   // Try to translate the canonical type first
   const auto canonicalType = type.getCanonicalType();
   if (canonicalType != type)
@@ -6758,6 +6765,26 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
     }
   }
 
+  {
+    // Indexing into ConstantBuffers and TextureBuffers involves an additional
+    // FlatConversion node which casts the handle to the underlying structure
+    // type. We can look past the FlatConversion to continue to collect indices.
+    // For example: MyConstantBufferArray[0].structMember1
+    // `-MemberExpr .structMember1
+    //   `-ImplicitCastExpr 'const T' lvalue <FlatConversion>
+    //     `-ArraySubscriptExpr 'ConstantBuffer<T>':'ConstantBuffer<T>' lvalue
+    if (auto *castExpr = dyn_cast<ImplicitCastExpr>(expr)) {
+      if (castExpr->getCastKind() == CK_FlatConversion) {
+        const auto *subExpr = castExpr->getSubExpr();
+        const QualType subExprType = subExpr->getType();
+        if (isConstantTextureBuffer(subExprType)) {
+          return collectArrayStructIndices(subExpr, rawIndex, rawIndices,
+                                           indices, isMSOutAttribute);
+        }
+      }
+    }
+  }
+
   // This the deepest we can go. No more array or struct indexing.
   return expr;
 }

+ 1 - 1
tools/clang/lib/Sema/SemaExprMember.cpp

@@ -1656,7 +1656,7 @@ ExprResult Sema::ActOnMemberAccessExpr(Scope *S, Expr *Base,
 
   // HLSL Changes Start
   if (getLangOpts().HLSL) {
-    Result = hlsl::MaybeConvertScalarToVector(this, Base);
+    Result = hlsl::MaybeConvertMemberAccess(this, Base);
     if (Result.isInvalid()) return ExprError();
     Base = Result.get();
   }

+ 61 - 5
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -1278,6 +1278,9 @@ const ArBasicKind g_ArBasicKindsAsTypes[] =
   //AR_OBJECT_SAMPLERCUBE,
   AR_OBJECT_SAMPLERCOMPARISON,
 
+  AR_OBJECT_CONSTANT_BUFFER,
+  AR_OBJECT_TEXTURE_BUFFER,
+
   AR_OBJECT_POINTSTREAM,
   AR_OBJECT_LINESTREAM,
   AR_OBJECT_TRIANGLESTREAM,
@@ -1366,6 +1369,9 @@ const uint8_t g_ArBasicKindsTemplateCount[] =
   //AR_OBJECT_SAMPLERCUBE,
   0, // AR_OBJECT_SAMPLERCOMPARISON
 
+  1, //AR_OBJECT_CONSTANT_BUFFER,
+  1, //AR_OBJECT_TEXTURE_BUFFER,
+
   1, // AR_OBJECT_POINTSTREAM
   1, // AR_OBJECT_LINESTREAM
   1, // AR_OBJECT_TRIANGLESTREAM
@@ -1462,6 +1468,9 @@ const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] =
   //AR_OBJECT_SAMPLERCUBE,
   { 0, MipsFalse, SampleFalse }, // AR_OBJECT_SAMPLERCOMPARISON (SamplerComparison)
 
+  { 0, MipsFalse, SampleFalse }, // AR_OBJECT_CONSTANT_BUFFER
+  { 0, MipsFalse, SampleFalse }, // AR_OBJECT_TEXTURE_BUFFER
+
   { 0, MipsFalse, SampleFalse }, // AR_OBJECT_POINTSTREAM (PointStream)
   { 0, MipsFalse, SampleFalse }, // AR_OBJECT_LINESTREAM (LineStream)
   { 0, MipsFalse, SampleFalse }, // AR_OBJECT_TRIANGLESTREAM (TriangleStream)
@@ -3390,6 +3399,10 @@ private:
           recordDecl = CreateSubobjectRaytracingPipelineConfig1(*m_context);
           break;
         }
+      } else if (kind == AR_OBJECT_CONSTANT_BUFFER) {
+        recordDecl = DeclareConstantBufferViewType(*m_context, /*bTBuf*/false);
+      } else if (kind == AR_OBJECT_TEXTURE_BUFFER) {
+        recordDecl = DeclareConstantBufferViewType(*m_context, /*bTBuf*/true);
       } else if (kind == AR_OBJECT_RAY_QUERY) {
         recordDecl = DeclareRayQueryType(*m_context);
       } else if (kind == AR_OBJECT_RESOURCE) {
@@ -4563,6 +4576,32 @@ public:
     if (templateName.equals(StringRef("is_same"))) {
       return false;
     }
+    // Allow object type for Constant/TextureBuffer.
+    if (templateName == "ConstantBuffer" || templateName == "TextureBuffer") {
+      if (TemplateArgList.size() == 1) {
+        const TemplateArgumentLoc &argLoc = TemplateArgList[0];
+        const TemplateArgument &arg = argLoc.getArgument();
+        DXASSERT(arg.getKind() == TemplateArgument::ArgKind::Type, "");
+        QualType argType = arg.getAsType();
+        SourceLocation argSrcLoc = argLoc.getLocation();
+        if (IsScalarType(argType) || IsVectorType(m_sema, argType) ||
+            IsMatrixType(m_sema, argType) || argType->isArrayType()) {
+          m_sema->Diag(argSrcLoc,
+                       diag::err_hlsl_typeintemplateargument_requires_struct)
+              << argType;
+          return true;
+        }
+        if (const RecordType* recordType = argType->getAsStructureType()) {
+          if (!recordType->getDecl()->isCompleteDefinition()) {
+            m_sema->Diag(argSrcLoc, diag::err_typecheck_decl_incomplete_type)
+                << argType;
+
+            return true;
+          }
+        }
+      }
+      return false;
+    }
 
     bool isMatrix = Template->getCanonicalDecl() ==
                     m_matrixTemplateDecl->getCanonicalDecl();
@@ -4714,10 +4753,11 @@ public:
     SourceLocation OpLoc,
     SourceLocation MemberLoc);
 
-  /// <summary>If E is a scalar, converts it to a 1-element vector.</summary>
+  /// <summary>If E is a scalar, converts it to a 1-element vector. If E is a
+  /// Constant/TextureBuffer<T>, converts it to const T.</summary>
   /// <param name="E">Expression to convert.</param>
   /// <returns>The result of the conversion; or E if the type is not a scalar.</returns>
-  ExprResult MaybeConvertScalarToVector(_In_ clang::Expr* E);
+  ExprResult MaybeConvertMemberAccess(_In_ clang::Expr* E);
 
   clang::Expr *HLSLImpCastToScalar(
     _In_ clang::Sema* self,
@@ -7734,8 +7774,16 @@ ExprResult HLSLExternalSource::LookupArrayMemberExprForHLSL(
 }
   
 
-ExprResult HLSLExternalSource::MaybeConvertScalarToVector(_In_ clang::Expr* E) {
+ExprResult HLSLExternalSource::MaybeConvertMemberAccess(_In_ clang::Expr* E) {
   DXASSERT_NOMSG(E != nullptr);
+
+  if (IsHLSLBufferViewType(E->getType())) {
+    QualType targetType =
+        m_context->getConstType(hlsl::GetHLSLResourceResultType(E->getType()));
+    return ImplicitCastExpr::Create(*m_context, targetType,
+                                    CastKind::CK_FlatConversion, E, nullptr,
+                                    E->getValueKind());
+  }
   ArBasicKind basic = GetTypeElementKind(E->getType());
   if (!IS_BASIC_PRIMITIVE(basic)) {
     return E;
@@ -8408,6 +8456,14 @@ bool HLSLExternalSource::CanConvert(
     return false;
   }
 
+  // Cast cbuffer to its result value.
+  if ((SourceInfo.EltKind == AR_OBJECT_CONSTANT_BUFFER ||
+       SourceInfo.EltKind == AR_OBJECT_TEXTURE_BUFFER) &&
+      TargetInfo.ShapeKind == AR_TOBJ_COMPOUND) {
+    standard->Second = ICK_Flat_Conversion;
+    return hlsl::GetHLSLResourceResultType(source) == target;
+  }
+
   // Structure cast.
   SourceIsAggregate = SourceInfo.ShapeKind == AR_TOBJ_COMPOUND || SourceInfo.ShapeKind == AR_TOBJ_ARRAY;
   TargetIsAggregate = TargetInfo.ShapeKind == AR_TOBJ_COMPOUND || TargetInfo.ShapeKind == AR_TOBJ_ARRAY;
@@ -10244,11 +10300,11 @@ bool hlsl::LookupRecordMemberExprForHLSL(
   return false;
 }
 
-clang::ExprResult hlsl::MaybeConvertScalarToVector(
+clang::ExprResult hlsl::MaybeConvertMemberAccess(
   _In_ clang::Sema* self,
   _In_ clang::Expr* E)
 {
-  return HLSLExternalSource::FromSema(self)->MaybeConvertScalarToVector(E);
+  return HLSLExternalSource::FromSema(self)->MaybeConvertMemberAccess(E);
 }
 
 bool hlsl::TryStaticCastForHLSL(_In_ Sema* self, ExprResult &SrcExpr,

+ 19 - 0
tools/clang/test/HLSLFileCheck/hlsl/objects/Cbuffer/retCBV.hlsl

@@ -0,0 +1,19 @@
+// RUN: %dxc -E main -T ps_6_0  %s  | FileCheck %s
+
+// Make sure both a and b are used.
+// CHECK:extractvalue %dx.types.CBufRet.f32 %{{.*}}, 0
+// CHECK:extractvalue %dx.types.CBufRet.f32 %{{.*}}, 1
+struct S {
+  float a;
+  float b;
+};
+
+ConstantBuffer<S> c: register(b2, space5);;
+
+S getS() {
+  return c;
+}
+
+float main() : SV_Target {
+  return c.a + getS().b;
+}

+ 20 - 0
tools/clang/test/HLSLFileCheck/hlsl/objects/Cbuffer/retTBV.hlsl

@@ -0,0 +1,20 @@
+// RUN: %dxc -E main -T ps_6_0  %s  | FileCheck %s
+
+// Make sure both a and b are used.
+// CHECK:extractvalue %dx.types.ResRet.i32 %{{.*}}, 0
+// CHECK:extractvalue %dx.types.ResRet.i32 %{{.*}}, 1
+
+struct S {
+  float a;
+  float b;
+};
+
+TextureBuffer<S> c: register(t2, space5);;
+
+S getS() {
+  return c;
+}
+
+float main() : SV_Target {
+  return c.a + getS().b;
+}

+ 4 - 4
tools/clang/test/HLSLFileCheck/hlsl/objects/CbufferLegacy/cbuffer-struct.hlsl

@@ -8,13 +8,13 @@ struct S {
     float4 f;
 };
 
-// CHECK: error: 'const int' cannot be used as a type parameter where a struct is required
+// CHECK: error: 'int' cannot be used as a type parameter where a struct is required
 ConstantBuffer<int>      B1;
-// CHECK: error: 'const float2' cannot be used as a type parameter where a struct is required
+// CHECK: error: 'float2' cannot be used as a type parameter where a struct is required
 TextureBuffer<float2>    B2;
-// CHECK: error: 'const float3x4' cannot be used as a type parameter where a struct is required
+// CHECK: error: 'float3x4' cannot be used as a type parameter where a struct is required
 ConstantBuffer<float3x4> B3;
-// CHECK: error: 'const C' cannot be used as a type parameter where a struct is required
+
 TextureBuffer<C>         B4;
 // CHECK-NOT: const S
 ConstantBuffer<S>        B5;

+ 3 - 3
tools/clang/test/HLSLFileCheck/hlsl/objects/CbufferLegacy/cbuffer-structarray.hlsl

@@ -6,12 +6,12 @@ struct Foo {
 
 typedef Foo FooA[2];
 
-// CHECK: error: 'const FooA' (aka 'Foo const[2]') cannot be used as a type parameter where a struct is required
+// CHECK: error: 'FooA' (aka 'Foo [2]') cannot be used as a type parameter where a struct is required
 ConstantBuffer<FooA> CB1;
 
-// CHECK: error: 'const FooA' (aka 'Foo const[2]') cannot be used as a type parameter where a struct is required
+// CHECK: error: 'FooA' (aka 'Foo [2]') cannot be used as a type parameter where a struct is required
 ConstantBuffer<FooA> CB[4][3];
-// CHECK: error: 'const FooA' (aka 'Foo const[2]') cannot be used as a type parameter where a struct is required
+// CHECK: error: 'FooA' (aka 'Foo [2]') cannot be used as a type parameter where a struct is required
 TextureBuffer<FooA> TB[4][3];
 
 float4 main(int a : A) : SV_Target

+ 4 - 4
tools/clang/test/HLSLFileCheck/hlsl/objects/CbufferLegacy/ctbuf.hlsl

@@ -4,11 +4,11 @@ struct S {
   float4 f;
 };
 
-// CHECK: ConstantBuffer
-// CHECK: <line:9:16, col:19> col:19 myCBuffer 'const S'
+
+// CHECK: <line:9:1, col:19> col:19 myCBuffer 'ConstantBuffer<S>':'ConstantBuffer<S>'
 ConstantBuffer<S> myCBuffer;
-// CHECK: TextureBuffer
-// CHECK: <line:12:15, col:18> col:18 myTBffer 'const S'
+
+// CHECK: <line:12:1, col:18> col:18 myTBffer 'TextureBuffer<S>':'TextureBuffer<S>'
 TextureBuffer<S> myTBffer;
 
 // CHECK: cbuffer