Browse Source

[spirv] Legalization: associated counter in structs (#997)

For structs containing RW/Append/Consume structured buffers, to
properly track aliasing and use the correct associated counter
for each buffer after legalization, we need to create temporary
counter variables for all struct fields that are RW/Append/Consume
structured buffers and assign them accordingly if the corresponding
buffer field is updated.

Because of structs, now we have four forms an alias RW/Append/
Consume structured buffer can be in:

* 1 (AssocCounter#1). A stand-alone variable,
* 2 (AssocCounter#2). A struct field,
* 3 (AssocCounter#3). A struct containing alias fields,
* 4 (AssocCounter#4). A nested struct containing alias fields.

For AssocCounter#3 and AssocCounter#4, it means we need to update
all fields' associated counters.

This commit only handles the first three forms.
Lei Zhang 7 years ago
parent
commit
6f6c600ebc

+ 100 - 20
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -108,6 +108,35 @@ std::string StageVar::getSemanticStr() const {
   return ss.str();
 }
 
+uint32_t CounterIdAliasPair::get(ModuleBuilder &builder,
+                                 TypeTranslator &translator) const {
+  if (isAlias) {
+    const uint32_t counterVarType = builder.getPointerType(
+        translator.getACSBufferCounter(), spv::StorageClass::Uniform);
+    return builder.createLoad(counterVarType, resultId);
+  }
+  return resultId;
+}
+
+const CounterIdAliasPair *
+CounterVarFields::get(const llvm::SmallVectorImpl<uint32_t> &indices) const {
+  for (const auto &field : fields)
+    if (field.indices == indices)
+      return &field.counterVar;
+  return nullptr;
+}
+
+void CounterVarFields::assign(const CounterVarFields &srcFields,
+                              ModuleBuilder &builder,
+                              TypeTranslator &translator) const {
+  for (const auto &field : fields) {
+    const auto *srcField = srcFields.get(field.indices);
+    // TODO: this will fail for AssocCounter#4.
+    assert(srcField);
+    field.counterVar.assign(*srcField, builder, translator);
+  }
+}
+
 DeclResultIdMapper::SemanticInfo
 DeclResultIdMapper::getStageVarSemantic(const ValueDecl *decl) {
   for (auto *annotation : decl->getUnusualAnnotations()) {
@@ -260,12 +289,17 @@ uint32_t DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
   return id;
 }
 
-void DeclResultIdMapper::createFnParamCounterVar(const ParmVarDecl *param) {
-  if (counterVars.count(param))
-    return;
+void DeclResultIdMapper::createCounterVarForDecl(const DeclaratorDecl *decl) {
+  const QualType declType = getTypeOrFnRetType(decl);
 
-  if (TypeTranslator::isRWAppendConsumeSBuffer(param->getType()))
-    createCounterVar(param, /*isAlias=*/true);
+  if (!counterVars.count(decl) &&
+      TypeTranslator::isRWAppendConsumeSBuffer(declType)) {
+    createCounterVar(decl, /*isAlias=*/true);
+  } else if (!fieldCounterVars.count(decl) && declType->isStructureType() &&
+             // Exclude other resource types which are represented as structs
+             !hlsl::IsHLSLResourceType(declType)) {
+    createFieldCounterVars(decl);
+  }
 }
 
 uint32_t DeclResultIdMapper::createFnVar(const VarDecl *var,
@@ -288,9 +322,7 @@ uint32_t DeclResultIdMapper::createFileVar(const VarDecl *var,
       getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias, &info);
   const uint32_t id = theBuilder.addModuleVar(type, spv::StorageClass::Private,
                                               var->getName(), init);
-  info.setResultId(id);
-  if (!isAlias)
-    info.setStorageClass(spv::StorageClass::Private);
+  info.setResultId(id).setStorageClass(spv::StorageClass::Private);
 
   return id;
 }
@@ -517,17 +549,41 @@ uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
   return id;
 }
 
-const CounterIdAliasPair *
-DeclResultIdMapper::getCounterIdAliasPair(const DeclaratorDecl *decl) {
-  const auto counter = counterVars.find(decl);
-  if (counter != counterVars.end())
-    return &counter->second;
+const CounterIdAliasPair *DeclResultIdMapper::getCounterIdAliasPair(
+    const DeclaratorDecl *decl,
+    const llvm::SmallVector<uint32_t, 4> *indices) const {
+  if (indices) {
+    // Indices are provided. Walk through the fields of the decl.
+    const auto counter = fieldCounterVars.find(decl);
+    if (counter != fieldCounterVars.end())
+      return counter->second.get(*indices);
+  } else {
+    // No indices. Check the stand-alone entities.
+    const auto counter = counterVars.find(decl);
+    if (counter != counterVars.end())
+      return &counter->second;
+  }
+  return nullptr;
+}
+
+const CounterVarFields *
+DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) const {
+  const auto found = fieldCounterVars.find(decl);
+  if (found != fieldCounterVars.end())
+    return &found->second;
   return nullptr;
 }
 
-void DeclResultIdMapper::createCounterVar(const DeclaratorDecl *decl,
-                                          bool isAlias) {
-  const std::string counterName = "counter.var." + decl->getName().str();
+void DeclResultIdMapper::createCounterVar(
+    const DeclaratorDecl *decl, bool isAlias,
+    const llvm::SmallVector<uint32_t, 4> *indices) {
+  std::string counterName = "counter.var." + decl->getName().str();
+  if (indices) {
+    // Append field indices to the name
+    for (const auto index : *indices)
+      counterName += "." + std::to_string(index);
+  }
+
   uint32_t counterType = typeTranslator.getACSBufferCounter();
   // {RW|Append|Consume}StructuredBuffer are all in Uniform storage class.
   // Alias counter variables should be created into the Private storage class.
@@ -552,7 +608,33 @@ void DeclResultIdMapper::createCounterVar(const DeclaratorDecl *decl,
                               decl->getAttr<VKCounterBindingAttr>(), true);
   }
 
-  counterVars[decl] = {counterId, isAlias};
+  if (indices)
+    fieldCounterVars[decl].append(*indices, counterId);
+  else
+    counterVars[decl] = {counterId, isAlias};
+}
+
+void DeclResultIdMapper::createFieldCounterVars(
+    const DeclaratorDecl *rootDecl, const DeclaratorDecl *decl,
+    llvm::SmallVector<uint32_t, 4> *indices) {
+  const QualType type = getTypeOrFnRetType(decl);
+  const auto *recordType = type->getAs<RecordType>();
+  assert(recordType);
+  const auto *recordDecl = recordType->getDecl();
+
+  for (const auto *field : recordDecl->fields()) {
+    indices->push_back(field->getFieldIndex()); // Build up the index chain
+
+    const QualType fieldType = field->getType();
+    if (TypeTranslator::isRWAppendConsumeSBuffer(fieldType))
+      createCounterVar(rootDecl, /*isAlias=*/true, indices);
+    else if (fieldType->isStructureType() &&
+             !hlsl::IsHLSLResourceType(fieldType))
+      // Go recursively into all nested structs
+      createFieldCounterVars(rootDecl, field, indices);
+
+    indices->pop_back();
+  }
 }
 
 uint32_t
@@ -1889,9 +1971,7 @@ uint32_t DeclResultIdMapper::getTypeAndCreateCounterForPotentialAliasVar(
   if (genAlias) {
     needsLegalization = true;
 
-    if (!counterVars.count(decl) &&
-        TypeTranslator::isRWAppendConsumeSBuffer(type))
-      createCounterVar(decl, /*isAlias=*/true);
+    createCounterVarForDecl(decl);
 
     if (info)
       info->setContainsAliasComponent(true);

+ 126 - 20
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -140,22 +140,12 @@ public:
 
   /// Returns the pointer to the counter variable. Dereferences first if this is
   /// an alias to a counter variable.
-  uint32_t get(ModuleBuilder &builder, TypeTranslator &translator) const {
-    if (isAlias) {
-      const uint32_t counterVarType = builder.getPointerType(
-          translator.getACSBufferCounter(), spv::StorageClass::Uniform);
-      return builder.createLoad(counterVarType, resultId);
-    }
-    return resultId;
-  }
+  uint32_t get(ModuleBuilder &builder, TypeTranslator &translator) const;
 
   /// Stores the counter variable's pointer in srcPair to the curent counter
   /// variable. The current counter variable must be an alias.
-  void assign(const CounterIdAliasPair &srcPair, ModuleBuilder &builder,
-              TypeTranslator &translator) const {
-    assert(isAlias);
-    builder.createStore(resultId, srcPair.get(builder, translator));
-  }
+  inline void assign(const CounterIdAliasPair &srcPair, ModuleBuilder &builder,
+                     TypeTranslator &translator) const;
 
 private:
   uint32_t resultId;
@@ -163,6 +153,79 @@ private:
   bool isAlias;
 };
 
+/// A class for holding all the counter variables associated with a struct's
+/// fields
+///
+/// A alias local RW/Append/Consume structured buffer will need an associated
+/// counter variable generated. There are four forms such an alias buffer can
+/// be:
+///
+/// 1 (AssocCounter#1). A stand-alone variable,
+/// 2 (AssocCounter#2). A struct field,
+/// 3 (AssocCounter#3). A struct containing alias fields,
+/// 4 (AssocCounter#4). A nested struct containing alias fields.
+///
+/// We consider the first two cases as *final* alias entities; The last two
+/// cases are called as *intermediate* alias entities, since we can still
+/// decompose them and get final alias entities.
+///
+/// We need to create an associated counter variable no matter which form the
+/// alias buffer is in, which means we need to recursively visit all fields of a
+/// struct to discover if it's not AssocCounter#1. That means a hierarchy.
+///
+/// The purpose of this class is to provide such hierarchy in a *flattened* way.
+/// Each field's associated counter is represented with an index vector and the
+/// counter's <result-id>. For example, for the following structs,
+///
+/// struct S {
+///       RWStructuredBuffer s1;
+///   AppendStructuredBuffer s2;
+/// };
+///
+/// struct T {
+///   S t1;
+///   S t2;
+/// };
+///
+/// An instance of T will have four associated counters for
+///   field: indices, <result-id>
+///   t1.s1: [0, 0], <id-1>
+///   t1.s2: [0, 1], <id-2>
+///   t2.s1: [1, 0], <id-3>
+///   t2.s2: [1, 1], <id-4>
+class CounterVarFields {
+public:
+  CounterVarFields() = default;
+
+  /// Registers a field's associated counter.
+  void append(const llvm::SmallVector<uint32_t, 4> &indices, uint32_t counter) {
+    fields.emplace_back(indices, counter);
+  }
+
+  /// Returns the counter associated with the field at the given indices if it
+  /// has. Returns nullptr otherwise.
+  const CounterIdAliasPair *
+  get(const llvm::SmallVectorImpl<uint32_t> &indices) const;
+
+  /// Assigns to all the fields' associated counter from the srcFields.
+  /// This is for assigning a struct as whole: we need to update all the
+  /// associated counters in the target struct.
+  void assign(const CounterVarFields &srcFields, ModuleBuilder &builder,
+              TypeTranslator &translator) const;
+
+private:
+  struct IndexCounterPair {
+    IndexCounterPair(const llvm::SmallVector<uint32_t, 4> &idx,
+                     uint32_t counter)
+        : indices(idx), counterVar(counter, true) {}
+
+    llvm::SmallVector<uint32_t, 4> indices; ///< Index vector
+    CounterIdAliasPair counterVar;          ///< Counter variable information
+  };
+
+  llvm::SmallVector<IndexCounterPair, 4> fields;
+};
+
 /// \brief The class containing mappings from Clang frontend Decls to their
 /// corresponding SPIR-V <result-id>s.
 ///
@@ -216,7 +279,7 @@ public:
   /// This is meant to be used for forward-declared functions.
   ///
   /// Note: legalization specific code
-  void createFnParamCounterVar(const ParmVarDecl *param);
+  inline void createFnParamCounterVar(const ParmVarDecl *param);
 
   /// \brief Creates a function-scope variable in the current function and
   /// returns its <result-id>.
@@ -308,9 +371,17 @@ public:
 
   /// \brief Returns the associated counter's (<result-id>, is-alias-or-not)
   /// pair for the given {RW|Append|Consume}StructuredBuffer variable.
-  /// Returns nullptr if the given decl has no associated counter variable
-  /// created.
-  const CounterIdAliasPair *getCounterIdAliasPair(const DeclaratorDecl *decl);
+  /// If indices is not nullptr, walks trhough the fields of the decl, expected
+  /// to be of struct type, using the indices to find the field. Returns nullptr
+  /// if the given decl has no associated counter variable created.
+  const CounterIdAliasPair *getCounterIdAliasPair(
+      const DeclaratorDecl *decl,
+      const llvm::SmallVector<uint32_t, 4> *indices = nullptr) const;
+
+  /// \brief Returns all the associated counters for the given decl. The decl is
+  /// expected to be a struct containing alias RW/Append/Consume structured
+  /// buffers. Returns nullptr if it does not.
+  const CounterVarFields *getCounterVarFields(const DeclaratorDecl *decl) const;
 
   /// \brief Returns the <type-id> for the given cbuffer, tbuffer,
   /// ConstantBuffer, TextureBuffer, or push constant block.
@@ -484,14 +555,29 @@ private:
   bool validateVKBuiltins(const DeclaratorDecl *decl,
                           const hlsl::SigPoint *sigPoint);
 
-  /// Creates the associated counter variable for RW/Append/Consume
-  /// structured buffer.
+  /// Methods for creating counter variables associated with the given decl.
+
+  /// Creates assoicated counter variables for all AssocCounter cases (see the
+  /// comment of CounterVarFields). fields.
+  void createCounterVarForDecl(const DeclaratorDecl *decl);
+  /// Creates the associated counter variable for final RW/Append/Consume
+  /// structured buffer. Handles AssocCounter#1 and AssocCounter#2 (see the
+  /// comment of CounterVarFields).
   ///
   /// The counter variable will be created as an alias variable (of
   /// pointer-to-pointer type in Private storage class) if isAlias is true.
   ///
   /// Note: isAlias - legalization specific code
-  void createCounterVar(const DeclaratorDecl *decl, bool isAlias);
+  void
+  createCounterVar(const DeclaratorDecl *decl, bool isAlias,
+                   const llvm::SmallVector<uint32_t, 4> *indices = nullptr);
+  /// Creates all assoicated counter variables by recursively visiting decl's
+  /// fields. Handles AssocCounter#3 and AssocCounter#4 (see the comment of
+  /// CounterVarFields).
+  inline void createFieldCounterVars(const DeclaratorDecl *decl);
+  void createFieldCounterVars(const DeclaratorDecl *rootDecl,
+                              const DeclaratorDecl *decl,
+                              llvm::SmallVector<uint32_t, 4> *indices);
 
   /// Decorates varId of the given asType with proper interpolation modes
   /// considering the attributes on the given decl.
@@ -532,7 +618,11 @@ private:
   llvm::SmallVector<ResourceVar, 8> resourceVars;
   /// Mapping from {RW|Append|Consume}StructuredBuffers to their
   /// counter variables' (<result-id>, is-alias-or-not) pairs
+  ///
+  /// conterVars holds entities of AssocCounter#1, fieldCounterVars holds
+  /// entities of the rest.
   llvm::DenseMap<const DeclaratorDecl *, CounterIdAliasPair> counterVars;
+  llvm::DenseMap<const DeclaratorDecl *, CounterVarFields> fieldCounterVars;
 
   /// Mapping from cbuffer/tbuffer/ConstantBuffer/TextureBufer/push-constant
   /// to the <type-id>
@@ -593,6 +683,13 @@ public:
   GlPerVertex glPerVertex;
 };
 
+void CounterIdAliasPair::assign(const CounterIdAliasPair &srcPair,
+                                ModuleBuilder &builder,
+                                TypeTranslator &translator) const {
+  assert(isAlias);
+  builder.createStore(resultId, srcPair.get(builder, translator));
+}
+
 DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
                                        ASTContext &context,
                                        ModuleBuilder &builder,
@@ -613,6 +710,15 @@ bool DeclResultIdMapper::isInputStorageClass(const StageVar &v) {
          spv::StorageClass::Input;
 }
 
+void DeclResultIdMapper::createFnParamCounterVar(const ParmVarDecl *param) {
+  return createCounterVarForDecl(param);
+}
+
+void DeclResultIdMapper::createFieldCounterVars(const DeclaratorDecl *decl) {
+  llvm::SmallVector<uint32_t, 4> indices;
+  createFieldCounterVars(decl, decl, &indices);
+}
+
 } // end namespace spirv
 } // end namespace clang
 

+ 95 - 24
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1596,9 +1596,8 @@ SpirvEvalInfo SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
   // Handle assignment first since we need to evaluate rhs before lhs.
   // For other binary operations, we need to evaluate lhs before rhs.
   if (opcode == BO_Assign) {
-    if (const auto *dstDecl = getReferencedDef(expr->getLHS()))
-      // Update counter variable associated with lhs of assignments
-      tryToAssignCounterVar(dstDecl, expr->getRHS());
+    // Update counter variable associated with lhs of assignments
+    tryToAssignCounterVar(expr->getLHS(), expr->getRHS());
 
     return processAssignment(expr->getLHS(), loadIfGLValue(expr->getRHS()),
                              /*isCompoundAssignment=*/false);
@@ -2771,13 +2770,13 @@ uint32_t SPIRVEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
     (void)doExpr(object);
   }
 
-  const auto *buffer = getReferencedDef(object);
-  if (!buffer) {
-    emitError("method call syntax unimplemented", expr->getExprLoc());
+  const auto *counterPair = getFinalACSBufferCounter(object);
+  if (!counterPair) {
+    emitFatalError("cannot find the associated counter variable",
+                   object->getExprLoc());
     return 0;
   }
 
-  const auto &counterPair = declIdMapper.getCounterIdAliasPair(buffer);
   const uint32_t counterPtrType = theBuilder.getPointerType(
       theBuilder.getInt32Type(), spv::StorageClass::Uniform);
   const uint32_t counterPtr = theBuilder.createAccessChain(
@@ -2800,37 +2799,109 @@ uint32_t SPIRVEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
 
 bool SPIRVEmitter::tryToAssignCounterVar(const DeclaratorDecl *dstDecl,
                                          const Expr *srcExpr) {
+  // We are handling associated counters here. Casts should not alter which
+  // associated counter to manipulate.
+  srcExpr = srcExpr->IgnoreParenCasts();
+
   // For parameters of forward-declared functions. We must make sure the
   // associated counter variable is created. But for forward-declared functions,
   // the translation of the real definition may not be started yet.
   if (const auto *param = dyn_cast<ParmVarDecl>(dstDecl))
     declIdMapper.createFnParamCounterVar(param);
 
-  if (TypeTranslator::isRWAppendConsumeSBuffer(getTypeOrFnRetType(dstDecl))) {
-    // Internal RW/Append/Consume StructuredBuffer. We also need to
-    // initialize the associated counter.
-    const auto *srcPair =
-        declIdMapper.getCounterIdAliasPair(getReferencedDef(srcExpr));
-    const auto *dstPair = declIdMapper.getCounterIdAliasPair(dstDecl);
-
+  // Handle AssocCounter#1 (see CounterVarFields comment)
+  if (const auto *dstPair = declIdMapper.getCounterIdAliasPair(dstDecl)) {
+    const auto *srcPair = getFinalACSBufferCounter(srcExpr);
     if (!srcPair) {
-      emitFatalError(
-          "cannot handle counter variable associated with the given expr",
-          srcExpr->getLocStart())
-          << srcExpr->getSourceRange();
+      emitFatalError("cannot find the associated counter variable",
+                     srcExpr->getExprLoc());
       return false;
     }
-    if (!dstDecl) {
-      emitFatalError(
-          "cannot handle counter variable associated with the given decl",
-          dstDecl->getLocation());
-      return false;
+    dstPair->assign(*srcPair, theBuilder, typeTranslator);
+    return true;
+  }
+
+  // AssocCounter#2 for the lhs cannot happen since the lhs is a stand-alone
+  // decl in this method.
+
+  // Handle AssocCounter#3
+  if (const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl)) {
+    if (const auto *srcDecl = getReferencedDef(srcExpr)) {
+      const auto *srcFields = declIdMapper.getCounterVarFields(srcDecl);
+      if (!srcFields) {
+        emitFatalError("cannot find the associated counter variable",
+                       srcExpr->getExprLoc());
+        return false;
+      }
+      dstFields->assign(*srcFields, theBuilder, typeTranslator);
+      return true;
     }
+  }
+
+  // Handle AssocCounter#4: TODO
+
+  return true;
+}
 
+bool SPIRVEmitter::tryToAssignCounterVar(const Expr *dstExpr,
+                                         const Expr *srcExpr) {
+  dstExpr = dstExpr->IgnoreParenCasts();
+  srcExpr = srcExpr->IgnoreParenCasts();
+
+  const auto *dstPair = getFinalACSBufferCounter(dstExpr);
+  const auto *srcPair = getFinalACSBufferCounter(srcExpr);
+
+  if ((dstPair == nullptr) != (srcPair == nullptr)) {
+    emitFatalError("cannot handle associated counter variable assignment",
+                   srcExpr->getExprLoc());
+    return false;
+  }
+
+  // Handle AssocCounter#1 & AssocCounter#2
+  if (dstPair && srcPair) {
     dstPair->assign(*srcPair, theBuilder, typeTranslator);
+    return true;
   }
 
-  return true;
+  // Handle AssocCounter#3
+  if (const auto *dstDecl = getReferencedDef(dstExpr))
+    if (const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl)) {
+      const auto *srcDecl = getReferencedDef(srcExpr);
+      if (!srcDecl) {
+        emitFatalError("cannot find the associated counter variable",
+                       srcExpr->getExprLoc());
+        return false;
+      }
+
+      const auto *srcFields = declIdMapper.getCounterVarFields(srcDecl);
+      if (!srcFields) {
+        emitFatalError("cannot find the associated counter variable",
+                       srcExpr->getExprLoc());
+        return false;
+      }
+
+      dstFields->assign(*srcFields, theBuilder, typeTranslator);
+      return true;
+    }
+
+  // Handle AssocCounter#4: TODO
+
+  return false;
+}
+
+const CounterIdAliasPair *
+SPIRVEmitter::getFinalACSBufferCounter(const Expr *expr) {
+  // AssocCounter#1: referencing some stand-alone variable
+  if (const auto *decl = getReferencedDef(expr))
+    return declIdMapper.getCounterIdAliasPair(decl);
+
+  // AssocCounter#2: referencing some non-struct field
+  llvm::SmallVector<uint32_t, 4> indices;
+  if (const auto *decl = getReferencedDef(
+          collectArrayStructIndices(expr, &indices, /*rawIndex=*/true)))
+    return declIdMapper.getCounterIdAliasPair(decl, &indices);
+
+  return nullptr;
 }
 
 SpirvEvalInfo

+ 8 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -691,6 +691,14 @@ private:
   /// Note: legalization specific code
   bool tryToAssignCounterVar(const DeclaratorDecl *dstDecl,
                              const Expr *srcExpr);
+  bool tryToAssignCounterVar(const Expr *dstExpr, const Expr *srcExpr);
+
+  /// Returns the counter variable's information associated with the entity
+  /// represented by the given decl.
+  ///
+  /// This method only handles final alias structured buffers, which means
+  /// AssocCounter#1 and AssocCounter#2.
+  const CounterIdAliasPair *getFinalACSBufferCounter(const Expr *decl);
 
   /// \brief Loads numWords 32-bit unsigned integers or stores numWords 32-bit
   /// unsigned integers (based on the doStore parameter) to the given

+ 207 - 0
tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.counter.struct.hlsl

@@ -0,0 +1,207 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct Bundle {
+      RWStructuredBuffer<float> rw;
+  AppendStructuredBuffer<float> append;
+ ConsumeStructuredBuffer<float> consume;
+};
+
+struct TwoBundle {
+    Bundle b1;
+    Bundle b2;
+};
+
+struct Wrapper {
+    TwoBundle b;
+};
+
+      RWStructuredBuffer<float> globalRWSBuffer;
+  AppendStructuredBuffer<float> globalASBuffer;
+ ConsumeStructuredBuffer<float> globalCSBuffer;
+
+Bundle  CreateBundle();
+Wrapper CreateWrapper();
+Wrapper ReturnWrapper(Wrapper wrapper);
+
+// Static variable
+static Bundle  staticBundle  = CreateBundle();
+static Wrapper staticWrapper = CreateWrapper();
+
+// CHECK: %counter_var_staticBundle_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_staticBundle_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_staticBundle_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// CHECK: %counter_var_staticWrapper_0_0_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_staticWrapper_0_0_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_staticWrapper_0_0_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_staticWrapper_0_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_staticWrapper_0_1_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_staticWrapper_0_1_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// CHECK: %counter_var_CreateBundle_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_CreateBundle_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_CreateBundle_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// CHECK: %counter_var_CreateWrapper_0_0_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_CreateWrapper_0_0_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_CreateWrapper_0_0_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_CreateWrapper_0_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_CreateWrapper_0_1_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_CreateWrapper_0_1_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// CHECK: %counter_var_localWrapper_0_0_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_localWrapper_0_0_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_localWrapper_0_0_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_localWrapper_0_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_localWrapper_0_1_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_localWrapper_0_1_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// CHECK: %counter_var_wrapper_0_0_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_wrapper_0_0_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_wrapper_0_0_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_wrapper_0_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_wrapper_0_1_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_wrapper_0_1_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// CHECK: %counter_var_ReturnWrapper_0_0_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_ReturnWrapper_0_0_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_ReturnWrapper_0_0_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_ReturnWrapper_0_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_ReturnWrapper_0_1_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_ReturnWrapper_0_1_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// CHECK: %counter_var_b_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_b_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_b_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// CHECK: %counter_var_w_0_0_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_w_0_0_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_w_0_0_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_w_0_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_w_0_1_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_w_0_1_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// CHECK-LABEL: %main = OpFunction
+
+    // Assign to static variable
+// CHECK:      [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_CreateBundle_0
+// CHECK-NEXT:                OpStore %counter_var_staticBundle_0 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_CreateBundle_1
+// CHECK-NEXT:                OpStore %counter_var_staticBundle_1 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_CreateBundle_2
+// CHECK-NEXT:                OpStore %counter_var_staticBundle_2 [[src]]
+
+// CHECK-LABEL: %src_main = OpFunction
+float main() : VALUE {
+    // Assign to the parameter
+// CHECK:      [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticWrapper_0_0_0
+// CHECK-NEXT:                OpStore %counter_var_wrapper_0_0_0 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticWrapper_0_0_1
+// CHECK-NEXT:                OpStore %counter_var_wrapper_0_0_1 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticWrapper_0_0_2
+// CHECK-NEXT:                OpStore %counter_var_wrapper_0_0_2 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticWrapper_0_1_0
+// CHECK-NEXT:                OpStore %counter_var_wrapper_0_1_0 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticWrapper_0_1_1
+// CHECK-NEXT:                OpStore %counter_var_wrapper_0_1_1 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticWrapper_0_1_2
+// CHECK-NEXT:                OpStore %counter_var_wrapper_0_1_2 [[src]]
+    // Make the call
+// CHECK:          {{%\d+}} = OpFunctionCall %Wrapper %ReturnWrapper %param_var_wrapper
+    // Assign to the return value
+// CHECK:      [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_ReturnWrapper_0_0_0
+// CHECK-NEXT:                OpStore %counter_var_localWrapper_0_0_0 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_ReturnWrapper_0_0_1
+// CHECK-NEXT:                OpStore %counter_var_localWrapper_0_0_1 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_ReturnWrapper_0_0_2
+// CHECK-NEXT:                OpStore %counter_var_localWrapper_0_0_2 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_ReturnWrapper_0_1_0
+// CHECK-NEXT:                OpStore %counter_var_localWrapper_0_1_0 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_ReturnWrapper_0_1_1
+// CHECK-NEXT:                OpStore %counter_var_localWrapper_0_1_1 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_ReturnWrapper_0_1_2
+// CHECK-NEXT:                OpStore %counter_var_localWrapper_0_1_2 [[src]]
+    Wrapper localWrapper = ReturnWrapper(staticWrapper);
+
+// CHECK:      [[cnt:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_0_0
+// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_int [[cnt]] %uint_0
+// CHECK-NEXT:     {{%\d+}} = OpAtomicIAdd %int [[ptr]] %uint_1 %uint_0 %int_1
+    localWrapper.b.b1.rw.IncrementCounter();
+// CHECK:      [[cnt:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_1_1
+// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_int [[cnt]] %uint_0
+// CHECK-NEXT: [[add:%\d+]] = OpAtomicIAdd %int [[ptr]] %uint_1 %uint_0 %int_1
+// CHECK-NEXT:     {{%\d+}} = OpAccessChain %_ptr_Uniform_float {{%\d+}} %uint_0 [[add]]
+    localWrapper.b.b2.append.Append(5.0);
+
+// CHECK:      [[cnt:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_ReturnWrapper_0_0_2
+// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_int [[cnt]] %uint_0
+// CHECK-NEXT: [[sub:%\d+]] = OpAtomicISub %int [[ptr]] %uint_1 %uint_0 %int_1
+// CHECK-NEXT: [[pre:%\d+]] = OpISub %int [[sub]] %int_1
+// CHECK-NEXT:     {{%\d+}} = OpAccessChain %_ptr_Uniform_float {{%\d+}} %uint_0 [[pre]]
+    return ReturnWrapper(staticWrapper).b.b1.consume.Consume();
+}
+
+// CHECK-LABEL: %CreateBundle = OpFunction
+Bundle CreateBundle() {
+    Bundle b;
+    // Assign to final struct fields who have associated counters
+// CHECK: OpStore %counter_var_b_0 %counter_var_globalRWSBuffer
+    b.rw      = globalRWSBuffer;
+// CHECK: OpStore %counter_var_b_1 %counter_var_globalASBuffer
+    b.append  = globalASBuffer;
+// CHECK: OpStore %counter_var_b_2 %counter_var_globalCSBuffer
+    b.consume = globalCSBuffer;
+
+    // Assign from local variable
+// CHECK:      [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_b_0
+// CHECK-NEXT:                OpStore %counter_var_CreateBundle_0 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_b_1
+// CHECK-NEXT:                OpStore %counter_var_CreateBundle_1 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_b_2
+// CHECK-NEXT:                OpStore %counter_var_CreateBundle_2 [[src]]
+    return b;
+}
+
+// CHECK-LABEL: %CreateWrapper = OpFunction
+Wrapper CreateWrapper() {
+    Wrapper w;
+
+    // Assign from final struct fields who have associated counters
+// CHECK:      [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticBundle_0
+// CHECK-NEXT:                OpStore %counter_var_w_0_0_0 [[src]]
+    w.b.b1.rw      = staticBundle.rw;
+// CHECK:      [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticBundle_1
+// CHECK-NEXT:                OpStore %counter_var_w_0_0_1 [[src]]
+    w.b.b1.append  = staticBundle.append;
+// CHECK:      [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticBundle_2
+// CHECK-NEXT:                OpStore %counter_var_w_0_0_2 [[src]]
+    w.b.b1.consume = staticBundle.consume;
+
+    // TODO:
+
+    // Assign to intermediate structs whose fields have associated counters
+    //w.b.b2         = staticBundle;
+
+    // Assign from intermediate structs whose fields have associated counters
+    //staticBundle   = w.b.b1;
+
+    return w;
+}
+
+// CHECK-LABEL: %ReturnWrapper = OpFunction
+Wrapper ReturnWrapper(Wrapper wrapper) {
+    // Assign from parameter
+// CHECK:      [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_wrapper_0_0_0
+// CHECK-NEXT:                OpStore %counter_var_ReturnWrapper_0_0_0 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_wrapper_0_0_1
+// CHECK-NEXT:                OpStore %counter_var_ReturnWrapper_0_0_1 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_wrapper_0_0_2
+// CHECK-NEXT:                OpStore %counter_var_ReturnWrapper_0_0_2 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_wrapper_0_1_0
+// CHECK-NEXT:                OpStore %counter_var_ReturnWrapper_0_1_0 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_wrapper_0_1_1
+// CHECK-NEXT:                OpStore %counter_var_ReturnWrapper_0_1_1 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_wrapper_0_1_2
+// CHECK-NEXT:                OpStore %counter_var_ReturnWrapper_0_1_2 [[src]]
+    return wrapper;
+}

+ 1 - 6
tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.struct.hlsl

@@ -84,15 +84,10 @@ float4 main() : SV_Target {
     int index = c.getT().getSBuffer2()[42];
 
 // CHECK:      [[val:%\d+]] = OpLoad %Combine %c
-// CHECK-NEXT:                OpStore %param_var_comb [[val]]
+// CHECK:                     OpStore %param_var_comb [[val]]
     return foo(c);
 }
 float4 foo(Combine comb) {
-    // TODO: add support for associated counters of struct fields
-    // comb.s.append.Append(float4(1, 2, 3, 4));
-    // float4 val = comb.s.consume.Consume();
-    // comb.t.rw[5].a = 4.2;
-
 // CHECK:      [[ptr1:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_ByteAddressBuffer %comb %int_2
 // CHECK-NEXT: [[ptr2:%\d+]] = OpLoad %_ptr_Uniform_type_ByteAddressBuffer [[ptr1]]
 // CHECK-NEXT:  [[idx:%\d+]] = OpShiftRightLogical %uint %uint_5 %uint_2

+ 4 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -1005,6 +1005,10 @@ TEST_F(FileTest, SpirvLegalizationStructuredBufferCounter) {
   runFileTest("spirv.legal.sbuffer.counter.hlsl", Expect::Success,
               /*runValidation=*/true, /*relaxLogicalPointer=*/true);
 }
+TEST_F(FileTest, SpirvLegalizationStructuredBufferCounterInStruct) {
+  runFileTest("spirv.legal.sbuffer.counter.struct.hlsl", Expect::Success,
+              /*runValidation=*/true, /*relaxLogicalPointer=*/true);
+}
 TEST_F(FileTest, SpirvLegalizationStructuredBufferInStruct) {
   runFileTest("spirv.legal.sbuffer.struct.hlsl", Expect::Success,
               /*runValidation=*/true, /*relaxLogicalPointer=*/true);