Przeglądaj źródła

[spirv] Add support for ConstantBuffer (#590)

Also optimized codgen for consecutive array/struct indexing
expressions to avoid generating multiple OpAccesChains.
Lei Zhang 8 lat temu
rodzic
commit
516f3270c8

+ 52 - 31
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -67,19 +67,11 @@ DeclResultIdMapper::getDeclSpirvInfo(const NamedDecl *decl) const {
 
 uint32_t DeclResultIdMapper::getDeclResultId(const NamedDecl *decl) {
   if (const auto *info = getDeclSpirvInfo(decl))
-    if (const auto *bufferDecl =
-            dyn_cast<HLSLBufferDecl>(decl->getDeclContext())) {
+    if (info->indexInCTBuffer >= 0) {
       // If this is a VarDecl inside a HLSLBufferDecl, we need to do an extra
       // OpAccessChain to get the pointer to the variable since we created
       // a single variable for the whole buffer object.
 
-      uint32_t index = 0;
-      for (const auto *subDecl : bufferDecl->decls()) {
-        if (subDecl == decl)
-          break;
-        ++index;
-      }
-
       const uint32_t varType = typeTranslator.translateType(
           // Should only have VarDecls in a HLSLBufferDecl.
           cast<VarDecl>(decl)->getType(),
@@ -88,7 +80,7 @@ uint32_t DeclResultIdMapper::getDeclResultId(const NamedDecl *decl) {
           /*decorateLayout*/ true);
       return theBuilder.createAccessChain(
           theBuilder.getPointerType(varType, info->storageClass),
-          info->resultId, {theBuilder.getConstantInt32(index)});
+          info->resultId, {theBuilder.getConstantInt32(info->indexInCTBuffer)});
     } else {
       return info->resultId;
     }
@@ -100,7 +92,7 @@ uint32_t DeclResultIdMapper::getDeclResultId(const NamedDecl *decl) {
 uint32_t DeclResultIdMapper::createFnParam(uint32_t paramType,
                                            const ParmVarDecl *param) {
   const uint32_t id = theBuilder.addFnParam(paramType, param->getName());
-  astDecls[param] = {id, spv::StorageClass::Function};
+  astDecls[param] = {id, spv::StorageClass::Function, -1};
 
   return id;
 }
@@ -108,7 +100,7 @@ uint32_t DeclResultIdMapper::createFnParam(uint32_t paramType,
 uint32_t DeclResultIdMapper::createFnVar(uint32_t varType, const VarDecl *var,
                                          llvm::Optional<uint32_t> init) {
   const uint32_t id = theBuilder.addFnVar(varType, var->getName(), init);
-  astDecls[var] = {id, spv::StorageClass::Function};
+  astDecls[var] = {id, spv::StorageClass::Function, -1};
 
   return id;
 }
@@ -117,7 +109,7 @@ uint32_t DeclResultIdMapper::createFileVar(uint32_t varType, const VarDecl *var,
                                            llvm::Optional<uint32_t> init) {
   const uint32_t id = theBuilder.addModuleVar(
       varType, spv::StorageClass::Private, var->getName(), init);
-  astDecls[var] = {id, spv::StorageClass::Private};
+  astDecls[var] = {id, spv::StorageClass::Private, -1};
 
   return id;
 }
@@ -135,55 +127,64 @@ uint32_t DeclResultIdMapper::createExternVar(uint32_t varType,
 
   const uint32_t id = theBuilder.addModuleVar(varType, storageClass,
                                               var->getName(), llvm::None);
-  astDecls[var] = {id, storageClass};
+  astDecls[var] = {id, storageClass, -1};
   resourceVars.emplace_back(id, getResourceBinding(var),
                             var->getAttr<VKBindingAttr>());
 
   return id;
 }
 
-uint32_t DeclResultIdMapper::createExternVar(const HLSLBufferDecl *decl) {
-  // In the AST, cbuffer/tbuffer is represented as a HLSLBufferDecl, which is
-  // a DeclContext, and all fields in the buffer are represented as VarDecls.
-  // We cannot do the normal translation path, which will translate a field
-  // into a standalone variable. We need to create a single SPIR-V variable
-  // for the whole buffer.
-
+uint32_t
+DeclResultIdMapper::createVarOfExplicitLayoutStruct(const DeclContext *decl,
+                                                    llvm::StringRef typeName,
+                                                    llvm::StringRef varName) {
   // Collect the type and name for each field
   llvm::SmallVector<uint32_t, 4> fieldTypes;
   llvm::SmallVector<llvm::StringRef, 4> fieldNames;
   for (const auto *subDecl : decl->decls()) {
-    // All the fields should be VarDecls.
-    const auto *varDecl = cast<VarDecl>(subDecl);
+    // Implicit generated struct declarations should be ignored.
+    if (isa<CXXRecordDecl>(subDecl) && subDecl->isImplicit())
+      continue;
+
+    // The field can only be FieldDecl (for normal structs) or VarDecl (for
+    // HLSLBufferDecls).
+    assert(isa<VarDecl>(subDecl) || isa<FieldDecl>(subDecl));
+    const auto *declDecl = cast<DeclaratorDecl>(subDecl);
     // All fields are qualified with const. It will affect the debug name.
     // We don't need it here.
-    auto varType = varDecl->getType();
+    auto varType = declDecl->getType();
     varType.removeLocalConst();
 
     fieldTypes.push_back(typeTranslator.translateType(
-        varType, true, varDecl->hasAttr<HLSLRowMajorAttr>()));
-    fieldNames.push_back(varDecl->getName());
+        varType, true, declDecl->hasAttr<HLSLRowMajorAttr>()));
+    fieldNames.push_back(declDecl->getName());
   }
 
   // Get the type for the whole buffer
-  const std::string structName = "type." + decl->getName().str();
   auto decorations = typeTranslator.getLayoutDecorations(decl);
   decorations.push_back(Decoration::getBlock(*theBuilder.getSPIRVContext()));
   const uint32_t structType =
-      theBuilder.getStructType(fieldTypes, structName, fieldNames, decorations);
+      theBuilder.getStructType(fieldTypes, typeName, fieldNames, decorations);
 
   // Create the variable for the whole buffer
+  return theBuilder.addModuleVar(structType, spv::StorageClass::Uniform,
+                                 varName);
+}
+
+uint32_t DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
+  const std::string structName = "type." + decl->getName().str();
   const std::string varName = "var." + decl->getName().str();
   const uint32_t bufferVar =
-      theBuilder.addModuleVar(structType, spv::StorageClass::Uniform, varName);
+      createVarOfExplicitLayoutStruct(decl, structName, varName);
 
   // We still register all VarDecls seperately here. All the VarDecls are
   // mapped to the <result-id> of the buffer object, which means when querying
   // querying the <result-id> for a certain VarDecl, we need to do an extra
   // OpAccessChain.
+  int index = 0;
   for (const auto *subDecl : decl->decls()) {
     const auto *varDecl = cast<VarDecl>(subDecl);
-    astDecls[varDecl] = {bufferVar, spv::StorageClass::Uniform};
+    astDecls[varDecl] = {bufferVar, spv::StorageClass::Uniform, index++};
   }
   resourceVars.emplace_back(bufferVar, getResourceBinding(decl),
                             decl->getAttr<VKBindingAttr>());
@@ -191,12 +192,32 @@ uint32_t DeclResultIdMapper::createExternVar(const HLSLBufferDecl *decl) {
   return bufferVar;
 }
 
+uint32_t DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
+  const auto *recordType = decl->getType()->getAs<RecordType>();
+  assert(recordType);
+  const auto *context = cast<HLSLBufferDecl>(decl->getDeclContext());
+  const bool isCBuffer = context->isCBuffer();
+
+  const std::string structName =
+      "type." + std::string(isCBuffer ? "ConstantBuffer." : "TextureBuffer") +
+      recordType->getDecl()->getName().str();
+  const uint32_t bufferVar = createVarOfExplicitLayoutStruct(
+      recordType->getDecl(), structName, decl->getName());
+
+  // We register the VarDecl here.
+  astDecls[decl] = {bufferVar, spv::StorageClass::Uniform, -1};
+  resourceVars.emplace_back(bufferVar, getResourceBinding(context),
+                            decl->getAttr<VKBindingAttr>());
+
+  return bufferVar;
+}
+
 uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
   if (const auto *info = getDeclSpirvInfo(fn))
     return info->resultId;
 
   const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
-  astDecls[fn] = {id, spv::StorageClass::Function};
+  astDecls[fn] = {id, spv::StorageClass::Function, -1};
 
   return id;
 }

+ 34 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -147,7 +147,27 @@ public:
 
   /// \brief Creates an external-visible variable and returns its <result-id>.
   uint32_t createExternVar(uint32_t varType, const VarDecl *var);
-  uint32_t createExternVar(const HLSLBufferDecl *decl);
+
+  /// \brief Creates a cbuffer/tbuffer from the given decl.
+  ///
+  /// In the AST, cbuffer/tbuffer is represented as a HLSLBufferDecl, which is
+  /// a DeclContext, and all fields in the buffer are represented as VarDecls.
+  /// We cannot do the normal translation path, which will translate a field
+  /// into a standalone variable. We need to create a single SPIR-V variable
+  /// for the whole buffer. When we refer to the field VarDecl later, we need
+  /// to do an extra OpAccessChain to get its pointer from the SPIR-V variable
+  /// standing for the whole buffer.
+  uint32_t createCTBuffer(const HLSLBufferDecl *decl);
+
+  /// \brief Creates a cbuffer/tbuffer from the given decl.
+  ///
+  /// In the AST, a variable whose type is ConstantBuffer/TextureBuffer is
+  /// represented as a VarDecl whose DeclContext is a HLSLBufferDecl. These
+  /// VarDecl's type is labelled as the struct upon which ConstantBuffer/
+  /// TextureBuffer is parameterized. For a such VarDecl, we need to create
+  /// a corresponding SPIR-V variable for it. Later referencing of such a
+  /// VarDecl does not need an extra OpAccessChain.
+  uint32_t createCTBuffer(const VarDecl *decl);
 
   /// \brief Sets the <result-id> of the entry function.
   void setEntryFunctionId(uint32_t id) { entryFunctionId = id; }
@@ -157,6 +177,9 @@ public:
   struct DeclSpirvInfo {
     uint32_t resultId;
     spv::StorageClass storageClass;
+    /// Value >= 0 means that this decl is a VarDecl inside a cbuffer/tbuffer
+    /// and this is the index; value < 0 means this is just a standalone decl.
+    int indexInCTBuffer;
   };
 
   /// \brief Returns the SPIR-V information for the given decl.
@@ -223,6 +246,16 @@ private:
   /// returns its result type.
   QualType getFnParamOrRetType(const DeclaratorDecl *decl) const;
 
+  /// Creates a variable of struct type with explicit layout decorations.
+  /// The sub-Decls in the given DeclContext will be treated as the struct
+  /// fields. The struct type will be named as typeName, and the variable
+  /// will be named as varName.
+  ///
+  /// Panics if the DeclContext is neither HLSLBufferDecl or RecordDecl.
+  uint32_t createVarOfExplicitLayoutStruct(const DeclContext *decl,
+                                           llvm::StringRef typeName,
+                                           llvm::StringRef varName);
+
   /// Creates all the stage variables mapped from semantics on the given decl
   /// and returns true on success.
   ///

+ 39 - 36
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -166,9 +166,15 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
         workQueue.insert(funcDecl);
       }
     } else if (auto *varDecl = dyn_cast<VarDecl>(decl)) {
-      doVarDecl(varDecl);
+      if (isa<HLSLBufferDecl>(varDecl->getDeclContext())) {
+        // This is a VarDecl of a ConstantBuffer/TextureBuffer type.
+        (void)declIdMapper.createCTBuffer(varDecl);
+      } else {
+        doVarDecl(varDecl);
+      }
     } else if (auto *bufferDecl = dyn_cast<HLSLBufferDecl>(decl)) {
-      (void)declIdMapper.createExternVar(bufferDecl);
+      // This is a cbuffer/tbuffer decl.
+      (void)declIdMapper.createCTBuffer(bufferDecl);
     }
   }
 
@@ -993,20 +999,14 @@ void SPIRVEmitter::doSwitchStmt(const SwitchStmt *switchStmt,
 }
 
 uint32_t SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
-  // The base of an ArraySubscriptExpr has a wrapping LValueToRValue implicit
-  // cast. We need to ingore it to avoid creating OpLoad.
-  const auto *baseExpr = expr->getBase()->IgnoreParenLValueCasts();
-
-  const uint32_t valType = typeTranslator.translateType(
-      // TODO: handle non-constant array types
-      astContext.getAsConstantArrayType(baseExpr->getType())->getElementType());
-  const uint32_t ptrType = theBuilder.getPointerType(
-      valType, declIdMapper.resolveStorageClass(baseExpr));
+  llvm::SmallVector<uint32_t, 4> indices;
+  const auto *base = collectArrayStructIndices(expr, &indices);
 
-  const uint32_t base = doExpr(baseExpr);
-  const uint32_t index = doExpr(expr->getIdx());
+  const uint32_t ptrType =
+      theBuilder.getPointerType(typeTranslator.translateType(expr->getType()),
+                                declIdMapper.resolveStorageClass(base));
 
-  return theBuilder.createAccessChain(ptrType, base, {index});
+  return theBuilder.createAccessChain(ptrType, doExpr(base), indices);
 }
 
 uint32_t SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
@@ -1834,15 +1834,13 @@ uint32_t SPIRVEmitter::doInitListExpr(const InitListExpr *expr) {
 
 uint32_t SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {
   llvm::SmallVector<uint32_t, 4> indices;
+  const Expr *base = collectArrayStructIndices(expr, &indices);
 
-  const Expr *baseExpr = collectStructIndices(expr, &indices);
-  const uint32_t base = doExpr(baseExpr);
+  const uint32_t ptrType =
+      theBuilder.getPointerType(typeTranslator.translateType(expr->getType()),
+                                declIdMapper.resolveStorageClass(base));
 
-  const uint32_t fieldType = typeTranslator.translateType(expr->getType());
-  const uint32_t ptrType = theBuilder.getPointerType(
-      fieldType, declIdMapper.resolveStorageClass(baseExpr));
-
-  return theBuilder.createAccessChain(ptrType, base, indices);
+  return theBuilder.createAccessChain(ptrType, doExpr(base), indices);
 }
 
 uint32_t SPIRVEmitter::doUnaryOperator(const UnaryOperator *expr) {
@@ -2676,25 +2674,30 @@ uint32_t SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
   return 0;
 }
 
-const Expr *
-SPIRVEmitter::collectStructIndices(const MemberExpr *expr,
-                                   llvm::SmallVectorImpl<uint32_t> *indices) {
-  const Expr *base = expr->getBase();
-  if (const auto *memExpr = dyn_cast<MemberExpr>(base)) {
-    base = collectStructIndices(memExpr, indices);
-  } else {
-    indices->clear();
-  }
+const Expr *SPIRVEmitter::collectArrayStructIndices(
+    const Expr *expr, llvm::SmallVectorImpl<uint32_t> *indices) {
+  if (const auto *indexing = dyn_cast<MemberExpr>(expr)) {
+    const Expr *base = collectArrayStructIndices(indexing->getBase(), indices);
 
-  const auto *memberDecl = expr->getMemberDecl();
-  if (const auto *fieldDecl = dyn_cast<FieldDecl>(memberDecl)) {
+    // Append the index of the current level
+    const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
+    assert(fieldDecl);
     indices->push_back(theBuilder.getConstantInt32(fieldDecl->getFieldIndex()));
-  } else {
-    emitError("Decl '%0' in MemberExpr is not supported yet.")
-        << memberDecl->getDeclKindName();
+
+    return base;
+  }
+
+  if (const auto *indexing = dyn_cast<ArraySubscriptExpr>(expr)) {
+    // The base of an ArraySubscriptExpr has a wrapping LValueToRValue implicit
+    // cast. We need to ingore it to avoid creating OpLoad.
+    const Expr *thisBase = indexing->getBase()->IgnoreParenLValueCasts();
+    const Expr *base = collectArrayStructIndices(thisBase, indices);
+    indices->push_back(doExpr(indexing->getIdx()));
+    return base;
   }
 
-  return base;
+  // This the deepest we can go. No more array or struct indexing.
+  return expr;
 }
 
 uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType,

+ 5 - 4
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -221,10 +221,11 @@ private:
                                  const BinaryOperatorKind opcode);
 
   /// Collects all indices (SPIR-V constant values) from consecutive MemberExprs
-  /// and writes into indices. Returns the real base (the first Expr that is not
-  /// a MemberExpr).
-  const Expr *collectStructIndices(const MemberExpr *expr,
-                                   llvm::SmallVectorImpl<uint32_t> *indices);
+  /// or ArraySubscriptExprs and writes into indices. Returns the real base
+  /// (the first Expr that is not a MemberExpr or ArraySubscriptExpr).
+  const Expr *
+  collectArrayStructIndices(const Expr *expr,
+                            llvm::SmallVectorImpl<uint32_t> *indices);
 
 private:
   /// Processes the given expr, casts the result into the given bool (vector)

+ 3 - 5
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -368,11 +368,9 @@ TypeTranslator::getLayoutDecorations(const DeclContext *decl) {
   uint32_t offset = 0, index = 0;
 
   for (const auto *field : decl->decls()) {
-    if (const auto *f = dyn_cast<CXXRecordDecl>(field)) {
-      // Implicit generated struct declarations should be ignored.
-      if (f->isImplicit())
-        continue;
-    }
+    // Implicit generated struct declarations should be ignored.
+    if (isa<CXXRecordDecl>(field) && field->isImplicit())
+      continue;
 
     // The field can only be FieldDecl (for normal structs) or VarDecl (for
     // HLSLBufferDecls).

+ 6 - 12
tools/clang/test/CodeGenSPIRV/op.array.access.hlsl

@@ -17,20 +17,14 @@ float main(float val: A, uint index: B) : C {
 
 // CHECK:       [[val:%\d+]] = OpLoad %float %val
 // CHECK-NEXT:  [[idx:%\d+]] = OpLoad %uint %index
-// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function__arr_S_uint_16 %var [[idx]]
-// CHECK-NEXT: [[ptr1:%\d+]] = OpAccessChain %_ptr_Function_S [[ptr0]] %int_1
-// CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Function__arr_float_uint_4 [[ptr1]] %int_0
-// CHECK-NEXT: [[ptr3:%\d+]] = OpAccessChain %_ptr_Function_float [[ptr2]] %int_2
-// CHECK-NEXT:                 OpStore [[ptr3]] [[val]]
+// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_float %var [[idx]] %int_1 %int_0 %int_2
+// CHECK-NEXT:                 OpStore [[ptr0]] [[val]]
 
     var[index][1].f[2] = val;
-// CHECK:      [[ptr0:%\d+]] = OpAccessChain %_ptr_Function__arr_S_uint_16 %var %int_0
-// CHECK-NEXT:  [[idx:%\d+]] = OpLoad %uint %index
-// CHECK-NEXT: [[ptr1:%\d+]] = OpAccessChain %_ptr_Function_S [[ptr0]] [[idx]]
-// CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Function__arr_float_uint_4 [[ptr1]] %int_1
-// CHECK-NEXT:  [[idx:%\d+]] = OpLoad %uint %index
-// CHECK-NEXT: [[ptr3:%\d+]] = OpAccessChain %_ptr_Function_float [[ptr2]] [[idx]]
-// CHECK-NEXT: [[load:%\d+]] = OpLoad %float [[ptr3]]
+// CHECK-NEXT: [[idx0:%\d+]] = OpLoad %uint %index
+// CHECK-NEXT: [[idx1:%\d+]] = OpLoad %uint %index
+// CHECK:      [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_float %var %int_0 [[idx0]] %int_1 [[idx1]]
+// CHECK-NEXT: [[load:%\d+]] = OpLoad %float [[ptr0]]
 // CHECK-NEXT:                 OpStore %r [[load]]
     r = var[0][index].g[index];
 

+ 37 - 0
tools/clang/test/CodeGenSPIRV/op.constant-buffer.access.hlsl

@@ -0,0 +1,37 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct S {
+    float  f;
+};
+
+struct T {
+    float    a;
+    float2   b;
+    float3x4 c;
+    S        s;
+    float    t[4];
+};
+
+
+ConstantBuffer<T> MyCbuffer : register(b1);
+
+float main() : A {
+// CHECK:      [[a:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_0
+// CHECK-NEXT: {{%\d+}} = OpLoad %float [[a]]
+
+// CHECK:      [[b:%\d+]] = OpAccessChain %_ptr_Uniform_v2float %MyCbuffer %int_1
+// CHECK-NEXT: [[b0:%\d+]] = OpAccessChain %_ptr_Uniform_float [[b]] %int_0
+// CHECK-NEXT: {{%\d+}} = OpLoad %float [[b0]]
+
+// CHECK:      [[c:%\d+]] = OpAccessChain %_ptr_Uniform_mat3v4float %MyCbuffer %int_2
+// CHECK-NEXT: [[c12:%\d+]] = OpAccessChain %_ptr_Uniform_float [[c]] %uint_1 %uint_2
+// CHECK-NEXT: {{%\d+}} = OpLoad %float [[c12]]
+
+// CHECK:      [[s:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_3 %int_0
+// CHECK-NEXT: {{%\d+}} = OpLoad %float [[s]]
+
+// CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_4 %int_3
+// CHECK-NEXT: {{%\d+}} = OpLoad %float [[t]]
+    return MyCbuffer.a + MyCbuffer.b.x + MyCbuffer.c[1][2] + MyCbuffer.s.f + MyCbuffer.t[3];
+}
+

+ 0 - 4
tools/clang/test/CodeGenSPIRV/type.cbuffer.hlsl

@@ -1,9 +1,5 @@
 // Run: %dxc -T vs_6_0 -E main
 
-// CHECK:      OpName %S "S"
-// CHECK-NEXT: OpMemberName %S 0 "f1"
-// CHECK-NEXT: OpMemberName %S 1 "f2"
-
 // CHECK:      OpName %type_MyCbuffer "type.MyCbuffer"
 // CHECK-NEXT: OpMemberName %type_MyCbuffer 0 "a"
 // CHECK-NEXT: OpMemberName %type_MyCbuffer 1 "b"

+ 35 - 0
tools/clang/test/CodeGenSPIRV/type.constant-buffer.hlsl

@@ -0,0 +1,35 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// CHECK:      OpName %type_ConstantBuffer_T "type.ConstantBuffer.T"
+// CHECK-NEXT: OpMemberName %type_ConstantBuffer_T 0 "a"
+// CHECK-NEXT: OpMemberName %type_ConstantBuffer_T 1 "b"
+// CHECK-NEXT: OpMemberName %type_ConstantBuffer_T 2 "c"
+// CHECK-NEXT: OpMemberName %type_ConstantBuffer_T 3 "d"
+// CHECK-NEXT: OpMemberName %type_ConstantBuffer_T 4 "s"
+// CHECK-NEXT: OpMemberName %type_ConstantBuffer_T 5 "t"
+
+// CHECK:      OpName %MyCbuffer "MyCbuffer"
+// CHECK:      OpName %AnotherCBuffer "AnotherCBuffer"
+struct S {
+    float  f1;
+    float3 f2;
+};
+
+// CHECK: %type_ConstantBuffer_T = OpTypeStruct %bool %int %v2uint %mat3v4float %S %_arr_float_uint_4
+// CHECK: %_ptr_Uniform_type_ConstantBuffer_T = OpTypePointer Uniform %type_ConstantBuffer_T
+struct T {
+    bool     a;
+    int      b;
+    uint2    c;
+    float3x4 d;
+    S        s;
+    float    t[4];
+};
+
+// CHECK: %MyCbuffer = OpVariable %_ptr_Uniform_type_ConstantBuffer_T Uniform
+ConstantBuffer<T> MyCbuffer : register(b1);
+// CHECK: %AnotherCBuffer = OpVariable %_ptr_Uniform_type_ConstantBuffer_T Uniform
+ConstantBuffer<T> AnotherCBuffer : register(b2);
+
+void main() {
+}

+ 12 - 32
tools/clang/test/CodeGenSPIRV/var.init.array.hlsl

@@ -31,43 +31,33 @@ void main() {
     T1 val1[2];
 
 // val2[0]: Construct T2.e from T1.c.b[0]
-// CHECK:       [[val1_0:%\d+]] = OpAccessChain %_ptr_Function_T1 %val1 %uint_0
-// CHECK-NEXT:  [[T1_c_b:%\d+]] = OpAccessChain %_ptr_Function__arr_v2float_uint_2 [[val1_0]] %int_0 %int_0
-// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float [[T1_c_b]] %uint_0
+// CHECK:          [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %uint_0 %int_0 %int_0 %uint_0
 // CHECK-NEXT: [[b_0_val:%\d+]] = OpLoad %v2float [[b_0]]
 // CHECK-NEXT:   [[e_val:%\d+]] = OpCompositeConstruct %S1 [[b_0_val]]
 
 // val2[0]: Construct T2.f from T1.c.b[1]
-// CHECK-NEXT:  [[val1_0:%\d+]] = OpAccessChain %_ptr_Function_T1 %val1 %uint_0
-// CHECK-NEXT:  [[T1_c_b:%\d+]] = OpAccessChain %_ptr_Function__arr_v2float_uint_2 [[val1_0]] %int_0 %int_0
-// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float [[T1_c_b]] %uint_1
+// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %uint_0 %int_0 %int_0 %uint_1
 // CHECK-NEXT: [[b_1_val:%\d+]] = OpLoad %v2float [[b_1]]
 // CHECK-NEXT:   [[f_val:%\d+]] = OpCompositeConstruct %S1 [[b_1_val]]
 
 // val2[0]: Read T1.d as T2.g
-// CHECK-NEXT:  [[val1_0:%\d+]] = OpAccessChain %_ptr_Function_T1 %val1 %uint_0
-// CHECK-NEXT:    [[T1_d:%\d+]] = OpAccessChain %_ptr_Function_S2 [[val1_0]] %int_1
+// CHECK-NEXT:    [[T1_d:%\d+]] = OpAccessChain %_ptr_Function_S2 %val1 %uint_0 %int_1
 // CHECK-NEXT:   [[d_val:%\d+]] = OpLoad %S2 [[T1_d]]
 
 // CHECK-NEXT:  [[val2_0:%\d+]] = OpCompositeConstruct %T2 [[e_val]] [[f_val]] [[d_val]]
 
 // val2[1]: Construct T2.e from T1.c.b[0]
-// CHECK-NEXT:  [[val1_1:%\d+]] = OpAccessChain %_ptr_Function_T1 %val1 %uint_1
-// CHECK-NEXT:  [[T1_c_b:%\d+]] = OpAccessChain %_ptr_Function__arr_v2float_uint_2 [[val1_1]] %int_0 %int_0
-// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float [[T1_c_b]] %uint_0
+// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %uint_1 %int_0 %int_0 %uint_0
 // CHECK-NEXT: [[b_0_val:%\d+]] = OpLoad %v2float [[b_0]]
 // CHECK-NEXT:   [[e_val:%\d+]] = OpCompositeConstruct %S1 [[b_0_val]]
 
 // val2[1]: Construct T2.f from T1.c.b[1]
-// CHECK-NEXT:  [[val1_1:%\d+]] = OpAccessChain %_ptr_Function_T1 %val1 %uint_1
-// CHECK-NEXT:  [[T1_c_b:%\d+]] = OpAccessChain %_ptr_Function__arr_v2float_uint_2 [[val1_1]] %int_0 %int_0
-// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float [[T1_c_b]] %uint_1
+// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %uint_1 %int_0 %int_0 %uint_1
 // CHECK-NEXT: [[b_1_val:%\d+]] = OpLoad %v2float [[b_1]]
 // CHECK-NEXT:   [[f_val:%\d+]] = OpCompositeConstruct %S1 [[b_1_val]]
 
 // val2[1]: Read T1.d as T2.g
-// CHECK-NEXT:  [[val1_1:%\d+]] = OpAccessChain %_ptr_Function_T1 %val1 %uint_1
-// CHECK-NEXT:    [[T1_d:%\d+]] = OpAccessChain %_ptr_Function_S2 [[val1_1]] %int_1
+// CHECK-NEXT:    [[T1_d:%\d+]] = OpAccessChain %_ptr_Function_S2 %val1 %uint_1 %int_1
 // CHECK-NEXT:   [[d_val:%\d+]] = OpLoad %S2 [[T1_d]]
 
 // CHECK-NEXT:  [[val2_1:%\d+]] = OpCompositeConstruct %T2 [[e_val]] [[f_val]] [[d_val]]
@@ -77,27 +67,19 @@ void main() {
     T2 val2[2] = {val1};
 
 // val3[0]: Construct T3.h from T1.c.b[0]
-// CHECK:       [[val1_0:%\d+]] = OpAccessChain %_ptr_Function_T1 %val1 %int_0
-// CHECK-NEXT:  [[T1_c_b:%\d+]] = OpAccessChain %_ptr_Function__arr_v2float_uint_2 [[val1_0]] %int_0 %int_0
-// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float [[T1_c_b]] %uint_0
+// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %int_0 %int_0 %int_0 %uint_0
 // CHECK-NEXT:   [[h_val:%\d+]] = OpLoad %v2float [[b_0]]
 
 // val3[0]: Construct T3.i from T1.c.b[1]
-// CHECK-NEXT:  [[val1_0:%\d+]] = OpAccessChain %_ptr_Function_T1 %val1 %int_0
-// CHECK-NEXT:  [[T1_c_b:%\d+]] = OpAccessChain %_ptr_Function__arr_v2float_uint_2 [[val1_0]] %int_0 %int_0
-// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float [[T1_c_b]] %uint_1
+// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %int_0 %int_0 %int_0 %uint_1
 // CHECK-NEXT:   [[i_val:%\d+]] = OpLoad %v2float [[b_1]]
 
 // val3[0]: Construct T3.j from T1.d.b[0]
-// CHECK-NEXT:  [[val1_0:%\d+]] = OpAccessChain %_ptr_Function_T1 %val1 %int_0
-// CHECK-NEXT:  [[T1_d_b:%\d+]] = OpAccessChain %_ptr_Function__arr_v2float_uint_2 [[val1_0]] %int_1 %int_0
-// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float [[T1_d_b]] %uint_0
+// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %int_0 %int_1 %int_0 %uint_0
 // CHECK-NEXT:   [[j_val:%\d+]] = OpLoad %v2float [[b_0]]
 
 // val3[0]: Construct T3.k from T1.d.b[1]
-// CHECK-NEXT:  [[val1_0:%\d+]] = OpAccessChain %_ptr_Function_T1 %val1 %int_0
-// CHECK-NEXT:  [[T1_d_b:%\d+]] = OpAccessChain %_ptr_Function__arr_v2float_uint_2 [[val1_0]] %int_1 %int_0
-// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float [[T1_d_b]] %uint_1
+// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %int_0 %int_1 %int_0 %uint_1
 // CHECK-NEXT:   [[k_val:%\d+]] = OpLoad %v2float [[b_1]]
 
 // CHECK-NEXT:  [[val3_0:%\d+]] = OpCompositeConstruct %T3 [[h_val]] [[i_val]] [[j_val]] [[k_val]]
@@ -110,13 +92,11 @@ void main() {
 // CHECK-NEXT:   [[h_val:%\d+]] = OpLoad %v2float [[s1_a]]
 
 // val3[2]: Construct T3.i from S2.b[0]
-// CHECK-NEXT:    [[s2_b:%\d+]] = OpAccessChain %_ptr_Function__arr_v2float_uint_2 %s2 %int_0
-// CHECK-NEXT:  [[s2_b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float [[s2_b]] %uint_0
+// CHECK-NEXT:  [[s2_b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float %s2 %int_0 %uint_0
 // CHECK-NEXT:   [[i_val:%\d+]] = OpLoad %v2float [[s2_b_0]]
 
 // val3[2]: Construct T3.j from S2.b[1]
-// CHECK-NEXT:    [[s2_b:%\d+]] = OpAccessChain %_ptr_Function__arr_v2float_uint_2 %s2 %int_0
-// CHECK-NEXT:  [[s2_b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float [[s2_b]] %uint_1
+// CHECK-NEXT:  [[s2_b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float %s2 %int_0 %uint_1
 // CHECK-NEXT:   [[j_val:%\d+]] = OpLoad %v2float [[s2_b_1]]
 
 // val3[2]: Construct T3.k from S1.a

+ 1 - 0
tools/clang/test/CodeGenSPIRV/vk.binding.explicit.hlsl

@@ -31,6 +31,7 @@ Buffer<int> myBuffer : register(t1, space0);
 RWBuffer<float4> myRWBuffer : register(u0, space1);
 
 // TODO: support [[vk::binding()]] on cbuffer
+// TODO: support [[vk::binding()]] on ConstantBuffer
 
 float4 main() : SV_Target {
     return 1.0;

+ 8 - 0
tools/clang/test/CodeGenSPIRV/vk.binding.implicit.hlsl

@@ -30,6 +30,14 @@ Buffer<int> myBuffer;
 // CHECK-NEXT: OpDecorate %myRWBuffer Binding 6
 RWBuffer<float4> myRWBuffer;
 
+struct S {
+    float4 f;
+};
+
+// CHECK:      OpDecorate %myCbuffer2 DescriptorSet 0
+// CHECK-NEXT: OpDecorate %myCbuffer2 Binding 7
+ConstantBuffer<S> myCbuffer2;
+
 float4 main() : SV_Target {
     return 1.0;
 }

+ 12 - 0
tools/clang/test/CodeGenSPIRV/vk.binding.register.hlsl

@@ -42,6 +42,18 @@ Buffer<int> myBuffer : register(t3, space0);
 // CHECK-NEXT: OpDecorate %myRWBuffer Binding 4
 RWBuffer<float4> myRWBuffer : register(u4, space1);
 
+struct S {
+    float4 f;
+};
+
+// CHECK:      OpDecorate %myCbuffer2 DescriptorSet 2
+// CHECK-NEXT: OpDecorate %myCbuffer2 Binding 2
+ConstantBuffer<S> myCbuffer2 : register(b2, space2);
+
+// CHECK:      OpDecorate %myCbuffer3 DescriptorSet 3
+// CHECK-NEXT: OpDecorate %myCbuffer3 Binding 2
+ConstantBuffer<S> myCbuffer3 : register(b2, space3);
+
 float4 main() : SV_Target {
     return 1.0;
 }

+ 7 - 1
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -49,6 +49,9 @@ TEST_F(FileTest, SamplerTypes) { runFileTest("type.sampler.hlsl"); }
 TEST_F(FileTest, TextureTypes) { runFileTest("type.texture.hlsl"); }
 TEST_F(FileTest, BufferType) { runFileTest("type.buffer.hlsl"); }
 TEST_F(FileTest, CBufferType) { runFileTest("type.cbuffer.hlsl"); }
+TEST_F(FileTest, ConstantBufferType) {
+  runFileTest("type.constant-buffer.hlsl");
+}
 TEST_F(FileTest, ByteAddressBufferTypes) {
   runFileTest("type.byte-address-buffer.hlsl");
 }
@@ -194,8 +197,11 @@ TEST_F(FileTest, OpMatrixAccess1x1) {
 
 // For struct & array accessing operator
 TEST_F(FileTest, OpStructAccess) { runFileTest("op.struct.access.hlsl"); }
+TEST_F(FileTest, OpArrayAccess) { runFileTest("op.array.access.hlsl"); }
 TEST_F(FileTest, OpCBufferAccess) { runFileTest("op.cbuffer.access.hlsl"); }
-TEST_F(FileTest, OpStructArray) { runFileTest("op.array.access.hlsl"); }
+TEST_F(FileTest, OpConstantBufferAccess) {
+  runFileTest("op.constant-buffer.access.hlsl");
+}
 
 // For Buffer/RWBuffer accessing operator
 TEST_F(FileTest, OpBufferAccess) { runFileTest("op.buffer.access.hlsl"); }