Bläddra i källkod

[spirv] Emit counter variables for alias structured/byte buffers (#923)

Also handled updating counter variable aliases in assignment and
function calls.
Lei Zhang 7 år sedan
förälder
incheckning
a269cf44fe

+ 92 - 36
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -245,27 +245,41 @@ SpirvEvalInfo DeclResultIdMapper::getDeclResultId(const ValueDecl *decl,
 
 uint32_t DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
   bool isAlias = false;
-  const uint32_t type = getTypeForPotentialAliasVar(param, &isAlias);
+  auto &info = astDecls[param].info;
+  const uint32_t type = getTypeForPotentialAliasVar(param, &isAlias, &info);
   const uint32_t ptrType =
       theBuilder.getPointerType(type, spv::StorageClass::Function);
   const uint32_t id = theBuilder.addFnParam(ptrType, param->getName());
-  astDecls[param] = SpirvEvalInfo(id)
-                        .setStorageClass(isAlias ? spv::StorageClass::Uniform
-                                                 : spv::StorageClass::Function)
-                        .setValTypeId(isAlias ? type : 0);
+  info.setResultId(id);
+
+  // The counter variable may be created before by forward declaration.
+  if (!counterVars.count(param))
+    // Create alias counter variable if suitable
+    if (isAlias && TypeTranslator::isRWAppendConsumeSBuffer(param->getType()))
+      createCounterVar(param, /*isAlias=*/true);
 
   return id;
 }
 
+void DeclResultIdMapper::createFnParamCounterVar(const ParmVarDecl *param) {
+  if (counterVars.count(param))
+    return;
+
+  if (TypeTranslator::isRWAppendConsumeSBuffer(param->getType()))
+    createCounterVar(param, /*isAlias=*/true);
+}
+
 uint32_t DeclResultIdMapper::createFnVar(const VarDecl *var,
                                          llvm::Optional<uint32_t> init) {
   bool isAlias = false;
-  const uint32_t type = getTypeForPotentialAliasVar(var, &isAlias);
+  auto &info = astDecls[var].info;
+  const uint32_t type = getTypeForPotentialAliasVar(var, &isAlias, &info);
   const uint32_t id = theBuilder.addFnVar(type, var->getName(), init);
-  astDecls[var] = SpirvEvalInfo(id)
-                      .setStorageClass(isAlias ? spv::StorageClass::Uniform
-                                               : spv::StorageClass::Function)
-                      .setValTypeId(isAlias ? type : 0);
+  info.setResultId(id);
+
+  // Create alias counter variable if suitable
+  if (isAlias && TypeTranslator::isRWAppendConsumeSBuffer(var->getType()))
+    createCounterVar(var, /*isAlias=*/true);
 
   return id;
 }
@@ -273,13 +287,17 @@ uint32_t DeclResultIdMapper::createFnVar(const VarDecl *var,
 uint32_t DeclResultIdMapper::createFileVar(const VarDecl *var,
                                            llvm::Optional<uint32_t> init) {
   bool isAlias = false;
-  const uint32_t type = getTypeForPotentialAliasVar(var, &isAlias);
+  auto &info = astDecls[var].info;
+  const uint32_t type = getTypeForPotentialAliasVar(var, &isAlias, &info);
   const uint32_t id = theBuilder.addModuleVar(type, spv::StorageClass::Private,
                                               var->getName(), init);
-  astDecls[var] = SpirvEvalInfo(id)
-                      .setStorageClass(isAlias ? spv::StorageClass::Uniform
-                                               : spv::StorageClass::Private)
-                      .setValTypeId(isAlias ? type : 0);
+  info.setResultId(id);
+  if (!isAlias)
+    info.setStorageClass(spv::StorageClass::Private);
+
+  // Create alias counter variable if suitable
+  if (isAlias && TypeTranslator::isRWAppendConsumeSBuffer(var->getType()))
+    createCounterVar(var, /*isAlias=*/true);
 
   return id;
 }
@@ -327,7 +345,7 @@ uint32_t DeclResultIdMapper::createExternVar(const VarDecl *var) {
   if (isACRWSBuffer) {
     // For {Append|Consume|RW}StructuredBuffer, we need to always create another
     // variable for its associated counter.
-    createCounterVar(var);
+    createCounterVar(var, /*isAlias=*/false);
   }
 
   return id;
@@ -481,31 +499,62 @@ uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
   if (const auto *info = getDeclSpirvInfo(fn))
     return info->info;
 
+  auto &info = astDecls[fn].info;
+
+  bool isAlias = false;
+  const uint32_t type = getTypeForPotentialAliasVar(fn, &isAlias, &info);
+
   const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
-  astDecls[fn] = SpirvEvalInfo(id);
+  info.setResultId(id);
+  if (isAlias)
+    // No need to dereference to get the pointer. Alias function returns
+    // themselves are already pointers to values.
+    info.setValTypeId(0);
+  else
+    // All other cases should be normal rvalues.
+    info.setRValue();
+
+  // Create alias counter variable if suitable
+  if (TypeTranslator::isRWAppendConsumeSBuffer(fn->getReturnType()))
+    createCounterVar(fn, /*isAlias=*/true);
 
   return id;
 }
 
-uint32_t DeclResultIdMapper::getOrCreateCounterId(const ValueDecl *decl) {
+const CounterIdAliasPair &
+DeclResultIdMapper::getCounterIdAliasPair(const ValueDecl *decl) {
   const auto counter = counterVars.find(decl);
-  if (counter != counterVars.end())
-    return counter->second;
-  return createCounterVar(decl);
+  assert(counter != counterVars.end());
+  return counter->second;
 }
 
-uint32_t DeclResultIdMapper::createCounterVar(const ValueDecl *decl) {
-  const auto *info = getDeclSpirvInfo(decl);
-  const uint32_t counterType = typeTranslator.getACSBufferCounter();
+void DeclResultIdMapper::createCounterVar(const ValueDecl *decl, bool isAlias) {
   const std::string counterName = "counter.var." + decl->getName().str();
-  const uint32_t counterId = theBuilder.addModuleVar(
-      counterType, info->info.getStorageClass(), counterName);
-
-  resourceVars.emplace_back(counterId, ResourceVar::Category::Other,
-                            getResourceBinding(decl),
-                            decl->getAttr<VKBindingAttr>(),
-                            decl->getAttr<VKCounterBindingAttr>(), true);
-  return counterVars[decl] = counterId;
+  uint32_t counterType = typeTranslator.getACSBufferCounter();
+  // {RW|Append|Consume}StructuredBuffer are all in Uniform storage class.
+  // Alias counter variables should be created into the Private storage class.
+  const spv::StorageClass sc =
+      isAlias ? spv::StorageClass::Private : spv::StorageClass::Uniform;
+
+  if (isAlias) {
+    // Apply an extra level of pointer for alias counter variable
+    counterType =
+        theBuilder.getPointerType(counterType, spv::StorageClass::Uniform);
+  }
+
+  const uint32_t counterId =
+      theBuilder.addModuleVar(counterType, sc, counterName);
+
+  if (!isAlias) {
+    // Non-alias counter variables should be put in to resourceVars so that
+    // descriptors can be allocated for them.
+    resourceVars.emplace_back(counterId, ResourceVar::Category::Other,
+                              getResourceBinding(decl),
+                              decl->getAttr<VKBindingAttr>(),
+                              decl->getAttr<VKCounterBindingAttr>(), true);
+  }
+
+  counterVars[decl] = {counterId, isAlias};
 }
 
 uint32_t
@@ -1804,9 +1853,8 @@ DeclResultIdMapper::getStorageClassForSigPoint(const hlsl::SigPoint *sigPoint) {
   return sc;
 }
 
-uint32_t
-DeclResultIdMapper::getTypeForPotentialAliasVar(const DeclaratorDecl *decl,
-                                                bool *shouldBeAlias) {
+uint32_t DeclResultIdMapper::getTypeForPotentialAliasVar(
+    const DeclaratorDecl *decl, bool *shouldBeAlias, SpirvEvalInfo *info) {
   if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
     // This method is only intended to be used to create SPIR-V variables in the
     // Function or Private storage class.
@@ -1841,7 +1889,15 @@ DeclResultIdMapper::getTypeForPotentialAliasVar(const DeclaratorDecl *decl,
     const uint32_t valType = typeTranslator.translateType(type, rule);
     // All constant/texture/structured/byte buffers are in the Uniform
     // storage class.
-    return theBuilder.getPointerType(valType, spv::StorageClass::Uniform);
+    const auto ptrType =
+        theBuilder.getPointerType(valType, spv::StorageClass::Uniform);
+
+    if (info)
+      info->setStorageClass(spv::StorageClass::Uniform)
+          .setLayoutRule(rule)
+          .setValTypeId(ptrType);
+
+    return ptrType;
   }
 
   return typeTranslator.translateType(type);

+ 54 - 7
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -131,6 +131,38 @@ private:
   bool isCounterVar;                          ///< Couter variable or not
 };
 
+/// A (<result-id>, is-alias-or-not) pair for counter variables
+class CounterIdAliasPair {
+public:
+  /// Default constructor to satisfy llvm::DenseMap
+  CounterIdAliasPair() : resultId(0), isAlias(false) {}
+  CounterIdAliasPair(uint32_t id, bool alias) : resultId(id), isAlias(alias) {}
+
+  /// Returns the pointer to the counter variable. Dereferences first if this is
+  /// an alias to a counter variable.
+  uint32_t get(ModuleBuilder &builder, TypeTranslator &translator) const {
+    if (isAlias) {
+      const uint32_t counterVarType = builder.getPointerType(
+          translator.getACSBufferCounter(), spv::StorageClass::Uniform);
+      return builder.createLoad(counterVarType, resultId);
+    }
+    return resultId;
+  }
+
+  /// Stores the counter variable's pointer in srcPair to the curent counter
+  /// variable. The current counter variable must be an alias.
+  void assign(const CounterIdAliasPair &srcPair, ModuleBuilder &builder,
+              TypeTranslator &translator) const {
+    assert(isAlias);
+    builder.createStore(resultId, srcPair.get(builder, translator));
+  }
+
+private:
+  uint32_t resultId;
+  /// Note: legalization specific code
+  bool isAlias;
+};
+
 /// \brief The class containing mappings from Clang frontend Decls to their
 /// corresponding SPIR-V <result-id>s.
 ///
@@ -180,6 +212,12 @@ public:
   /// returns its <result-id>.
   uint32_t createFnParam(const ParmVarDecl *param);
 
+  /// \brief Creates the counter variable associated with the given param.
+  /// This is meant to be used for forward-declared functions.
+  ///
+  /// Note: legalization specific code
+  void createFnParamCounterVar(const ParmVarDecl *param);
+
   /// \brief Creates a function-scope variable in the current function and
   /// returns its <result-id>.
   uint32_t createFnVar(const VarDecl *var, llvm::Optional<uint32_t> init);
@@ -219,9 +257,13 @@ public:
   /// pointer-to-the-value type will be returned, otherwise, just return the
   /// normal value type.
   ///
+  /// If the type is for an alias variable, writes true to *shouldBeAlias and
+  /// writes storage class, layout rule, and valTypeId to *info.
+  ///
   /// Note: legalization specific code
   uint32_t getTypeForPotentialAliasVar(const DeclaratorDecl *var,
-                                       bool *shouldBeAlias = nullptr);
+                                       bool *shouldBeAlias = nullptr,
+                                       SpirvEvalInfo *info = nullptr);
 
   /// \brief Sets the <result-id> of the entry function.
   void setEntryFunctionId(uint32_t id) { entryFunctionId = id; }
@@ -262,9 +304,9 @@ public:
   /// returns a newly assigned <result-id> for it.
   uint32_t getOrRegisterFnResultId(const FunctionDecl *fn);
 
-  /// \brief Returns the associated counter's <result-id> for the given
-  /// {RW|Append|Consume}StructuredBuffer variable.
-  uint32_t getOrCreateCounterId(const ValueDecl *decl);
+  /// \brief Returns the associated counter's (<result-id>, is-alias-or-not)
+  /// pair for the given {RW|Append|Consume}StructuredBuffer variable.
+  const CounterIdAliasPair &getCounterIdAliasPair(const ValueDecl *decl);
 
   /// \brief Returns the <type-id> for the given cbuffer, tbuffer,
   /// ConstantBuffer, TextureBuffer, or push constant block.
@@ -440,7 +482,12 @@ private:
 
   /// Creates the associated counter variable for RW/Append/Consume
   /// structured buffer.
-  uint32_t createCounterVar(const ValueDecl *decl);
+  ///
+  /// The counter variable will be created as an alias variable (of
+  /// pointer-to-pointer type in Private storage class) if isAlias is true.
+  ///
+  /// Note: isAlias - legalization specific code
+  void createCounterVar(const ValueDecl *decl, bool isAlias);
 
   /// Decorates varId of the given asType with proper interpolation modes
   /// considering the attributes on the given decl.
@@ -480,8 +527,8 @@ private:
   /// Vector of all defined resource variables.
   llvm::SmallVector<ResourceVar, 8> resourceVars;
   /// Mapping from {RW|Append|Consume}StructuredBuffers to their
-  /// counter variables
-  llvm::DenseMap<const ValueDecl *, uint32_t> counterVars;
+  /// counter variables' (<result-id>, is-alias-or-not) pairs
+  llvm::DenseMap<const ValueDecl *, CounterIdAliasPair> counterVars;
 
   /// Mapping from cbuffer/tbuffer/ConstantBuffer/TextureBufer/push-constant
   /// to the <type-id>

+ 123 - 27
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -24,6 +24,15 @@ namespace spirv {
 
 namespace {
 
+/// Returns the type of the given decl. If the given decl is a FunctionDecl,
+/// returns its result type.
+inline QualType getTypeOrFnRetType(const ValueDecl *decl) {
+  if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
+    return funcDecl->getReturnType();
+  }
+  return decl->getType();
+}
+
 // Returns true if the given decl has the given semantic.
 bool hasSemantic(const DeclaratorDecl *decl,
                  hlsl::DXIL::SemanticKind semanticKind) {
@@ -205,16 +214,23 @@ bool isConstantTextureBufferLoad(const Expr *expr) {
   return false;
 }
 
-/// Returns true if the given expr is an DeclRefExpr referencing a kind of
-/// structured or byte buffer and it is non-alias one.
+/// Returns true if
+/// * the given expr is an DeclRefExpr referencing a kind of structured or byte
+/// buffer and it is non-alias one, or
+/// * the given expr is an CallExpr returning a kind of structured or byte
+/// buffer.
 ///
 /// Note: legalization specific code
 bool isReferencingNonAliasStructuredOrByteBuffer(const Expr *expr) {
-  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr->IgnoreParenCasts()))
+  expr = expr->IgnoreParenCasts();
+  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
     if (const auto *varDecl = dyn_cast<VarDecl>(declRefExpr->getFoundDecl()))
-      if (TypeTranslator::isAKindOfStructuredOrByteBuffer(varDecl->getType())) {
+      if (TypeTranslator::isAKindOfStructuredOrByteBuffer(varDecl->getType()))
         return isExternalVar(varDecl);
-      }
+  } else if (const auto *callExpr = dyn_cast<CallExpr>(expr)) {
+    if (TypeTranslator::isAKindOfStructuredOrByteBuffer(callExpr->getType()))
+      return true;
+  }
   return false;
 }
 
@@ -331,6 +347,41 @@ inline const HLSLBufferDecl *getCTBufferContext(const VarDecl *varDecl) {
   return nullptr;
 }
 
+/// Returns the real definition of the callee of the given CallExpr.
+///
+/// If we are calling a forward-declared function, callee will be the
+/// FunctionDecl for the foward-declared function, not the actual
+/// definition. The foward-delcaration and defintion are two completely
+/// different AST nodes.
+inline const FunctionDecl *getCalleeDefinition(const CallExpr *expr) {
+  const auto *callee = expr->getDirectCallee();
+
+  if (callee->isThisDeclarationADefinition())
+    return callee;
+
+  // We need to update callee to the actual definition here
+  if (!callee->isDefined(callee))
+    return nullptr;
+
+  return callee;
+}
+
+/// Returns the referenced definition. The given expr is expected to be a
+/// DeclRefExpr or CallExpr after ignoring casts. Returns nullptr otherwise.
+const ValueDecl *getReferencedDef(const Expr *expr) {
+  expr = expr->IgnoreParenCasts();
+
+  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
+    return declRefExpr->getDecl();
+  }
+
+  if (const auto *callExpr = dyn_cast<CallExpr>(expr)) {
+    return getCalleeDefinition(callExpr);
+  }
+
+  return nullptr;
+}
+
 } // namespace
 
 SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
@@ -603,6 +654,10 @@ SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr) {
     // buffer as a whole. If true, it means we are creating alias for it. Avoid
     // the load and write the pointer directly to the alias variable then.
     //
+    // Also for the case of alias function returns. If we are trying to load an
+    // alias function return as a whole, it means we are assigning it to another
+    // alias variable. Avoid the load and write the pointer directly.
+    //
     // Note: legalization specific code
     if (isReferencingNonAliasStructuredOrByteBuffer(expr)) {
       return info.setRValue();
@@ -927,6 +982,9 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
         theBuilder.createStore(varId, constId);
       else
         storeValue(varId, loadIfGLValue(init), decl->getType(), init);
+
+      // Update counter variable associatd with local variables
+      tryToAssignCounterVar(decl, init);
     }
   }
 
@@ -1369,6 +1427,9 @@ void SPIRVEmitter::doIfStmt(const IfStmt *ifStmt) {
 
 void SPIRVEmitter::doReturnStmt(const ReturnStmt *stmt) {
   if (const auto *retVal = stmt->getRetValue()) {
+    // Update counter variable associatd with function returns
+    tryToAssignCounterVar(curFunction, retVal);
+
     const auto retInfo = doExpr(retVal);
     const auto retType = retVal->getType();
     if (retInfo.getStorageClass() != spv::StorageClass::Function &&
@@ -1493,6 +1554,10 @@ SpirvEvalInfo SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
   // Handle assignment first since we need to evaluate rhs before lhs.
   // For other binary operations, we need to evaluate lhs before rhs.
   if (opcode == BO_Assign) {
+    if (const auto *dstDecl = getReferencedDef(expr->getLHS()))
+      // Update counter variable associatd with lhs of assignments
+      tryToAssignCounterVar(dstDecl, expr->getRHS());
+
     return processAssignment(expr->getLHS(), loadIfGLValue(expr->getRHS()),
                              /*isCompoundAssignment=*/false, /*lhsPtr=*/0,
                              expr->getRHS());
@@ -1528,24 +1593,12 @@ SpirvEvalInfo SPIRVEmitter::doCallExpr(const CallExpr *callExpr) {
 }
 
 SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
-  const FunctionDecl *callee = callExpr->getDirectCallee();
+  const FunctionDecl *callee = getCalleeDefinition(callExpr);
 
-  // If we are calling a forward-declared function, callee will be the
-  // FunctionDecl for the foward-declared function, not the actual
-  // definition. The foward-delcaration and defintion are two completely
-  // different AST nodes.
   // Note that we always want the defintion because Stmts/Exprs in the
   // function body references the parameters in the definition.
-  if (!callee->isThisDeclarationADefinition()) {
-    // We need to update callee to the actual definition here
-    if (!callee->isDefined(callee)) {
-      emitError("found undefined function", callExpr->getExprLoc());
-      return 0;
-    }
-  }
-
   if (!callee) {
-    emitError("calling non-function unimplemented", callExpr->getExprLoc());
+    emitError("found undefined function", callExpr->getExprLoc());
     return 0;
   }
 
@@ -1602,6 +1655,9 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
     params.push_back(tempVarId);
     args.push_back(doExpr(arg));
 
+    // Update counter variable associated with function parameters
+    tryToAssignCounterVar(param, arg);
+
     if (canActAsOutParmVar(param)) {
       // The current parameter is marked as out/inout. The argument then is
       // essentially passed in by reference. We need to load the value
@@ -1637,7 +1693,8 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
     }
   }
 
-  return SpirvEvalInfo(retVal).setRValue();
+  // Inherit the SpirvEvalInfo from the function definition
+  return declIdMapper.getDeclResultId(callee).setResultId(retVal);
 }
 
 SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
@@ -2548,7 +2605,7 @@ SPIRVEmitter::processStructuredBufferLoad(const CXXMemberCallExpr *expr) {
 }
 
 uint32_t SPIRVEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
-                                                bool isInc) {
+                                                bool isInc, bool loadObject) {
   const uint32_t i32Type = theBuilder.getInt32Type();
   const uint32_t one = theBuilder.getConstantUint32(1);  // As scope: Device
   const uint32_t zero = theBuilder.getConstantUint32(0); // As memory sema: None
@@ -2556,13 +2613,26 @@ uint32_t SPIRVEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
 
   const auto *object =
       expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext);
-  const auto *buffer = cast<DeclRefExpr>(object)->getDecl();
 
-  const uint32_t counterVar = declIdMapper.getOrCreateCounterId(buffer);
+  if (loadObject) {
+    // We don't need the object's <result-id> here since counter variable is a
+    // separate variable. But we still need the side effects of evaluating the
+    // object, e.g., if the source code is foo(...).IncrementCounter(), we still
+    // want to emit the code for foo(...).
+    (void)doExpr(object);
+  }
+
+  const ValueDecl *buffer = getReferencedDef(object);
+  if (!buffer) {
+    emitError("method call syntax unimplemented", expr->getExprLoc());
+    return 0;
+  }
+
+  const auto &counterPair = declIdMapper.getCounterIdAliasPair(buffer);
   const uint32_t counterPtrType = theBuilder.getPointerType(
       theBuilder.getInt32Type(), spv::StorageClass::Uniform);
-  const uint32_t counterPtr =
-      theBuilder.createAccessChain(counterPtrType, counterVar, {zero});
+  const uint32_t counterPtr = theBuilder.createAccessChain(
+      counterPtrType, counterPair.get(theBuilder, typeTranslator), {zero});
 
   uint32_t index = 0;
   if (isInc) {
@@ -2579,6 +2649,25 @@ uint32_t SPIRVEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
   return index;
 }
 
+void SPIRVEmitter::tryToAssignCounterVar(const ValueDecl *dstDecl,
+                                         const Expr *srcExpr) {
+  // For parameters of forward-declared functions. We must make sure the
+  // associated counter variable is created. But for forward-declared functions,
+  // the translation of the real definition may not be started yet.
+  if (const auto *param = dyn_cast<ParmVarDecl>(dstDecl))
+    declIdMapper.createFnParamCounterVar(param);
+
+  if (TypeTranslator::isRWAppendConsumeSBuffer(getTypeOrFnRetType(dstDecl))) {
+    // Internal RW/Append/Consume StructuredBuffer. We also need to
+    // initialize the associated counter.
+    const auto &srcPair =
+        declIdMapper.getCounterIdAliasPair(getReferencedDef(srcExpr));
+    const auto &dstPair = declIdMapper.getCounterIdAliasPair(dstDecl);
+
+    dstPair.assign(srcPair, theBuilder, typeTranslator);
+  }
+}
+
 SpirvEvalInfo
 SPIRVEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) {
   const bool isAppend = expr->getNumArgs() == 1;
@@ -2587,9 +2676,13 @@ SPIRVEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) {
 
   const auto *object =
       expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext);
-  auto bufferInfo = doDeclRefExpr(cast<DeclRefExpr>(object));
 
-  uint32_t index = incDecRWACSBufferCounter(expr, isAppend);
+  auto bufferInfo = loadIfAliasVarRef(object);
+
+  uint32_t index = incDecRWACSBufferCounter(
+      expr, isAppend,
+      // We have already translated the object in the above. Avoid duplication.
+      /*loadObject=*/false);
 
   const auto bufferElemTy = hlsl::GetHLSLResourceResultType(object->getType());
 
@@ -6918,6 +7011,9 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
     const auto id = declIdMapper.getDeclResultId(varDecl);
     if (const auto *init = varDecl->getInit()) {
       theBuilder.createStore(id, doExpr(init));
+
+      // Update counter variable associatd with global variables
+      tryToAssignCounterVar(varDecl, init);
     } else {
       const auto typeId = typeTranslator.translateType(varDecl->getType());
       theBuilder.createStore(id, theBuilder.getConstantNull(typeId));

+ 11 - 2
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -660,8 +660,17 @@ private:
   SpirvEvalInfo processStructuredBufferLoad(const CXXMemberCallExpr *expr);
 
   /// \brief Increments or decrements the counter for RW/Append/Consume
-  /// structured buffer.
-  uint32_t incDecRWACSBufferCounter(const CXXMemberCallExpr *, bool isInc);
+  /// structured buffer. If loadObject is true, the object upon which the call
+  /// is made will be evaluated and translated into SPIR-V.
+  uint32_t incDecRWACSBufferCounter(const CXXMemberCallExpr *call, bool isInc,
+                                    bool loadObject = true);
+
+  /// Assigns the counter variable associated with srcExpr to the one associated
+  /// with dstDecl if the dstDecl is an internal RW/Append/Consume structured
+  /// buffer.
+  ///
+  /// Note: legalization specific code
+  void tryToAssignCounterVar(const ValueDecl *dstDecl, const Expr *srcExpr);
 
   /// \brief Loads numWords 32-bit unsigned integers or stores numWords 32-bit
   /// unsigned integers (based on the doStore parameter) to the given

+ 9 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -377,6 +377,15 @@ bool TypeTranslator::isConsumeStructuredBuffer(QualType type) {
   return name == "ConsumeStructuredBuffer";
 }
 
+bool TypeTranslator::isRWAppendConsumeSBuffer(QualType type) {
+  if (const RecordType *recordType = type->getAs<RecordType>()) {
+    StringRef name = recordType->getDecl()->getName();
+    return name == "RWStructuredBuffer" || name == "AppendStructuredBuffer" ||
+           name == "ConsumeStructuredBuffer";
+  }
+  return false;
+}
+
 bool TypeTranslator::isAKindOfStructuredOrByteBuffer(QualType type) {
   if (const RecordType *recordType = type->getAs<RecordType>()) {
     StringRef name = recordType->getDecl()->getName();

+ 4 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -74,6 +74,10 @@ public:
   /// \brief Returns true if the given type is a ConsumeStructuredBuffer type.
   static bool isConsumeStructuredBuffer(QualType type);
 
+  /// \brief Returns true if the given type is a RW/Append/Consume
+  /// StructuredBuffer type.
+  static bool isRWAppendConsumeSBuffer(QualType type);
+
   /// \brief Returns true if the given type is the HLSL ByteAddressBufferType.
   static bool isByteAddressBuffer(QualType type);
 

+ 1 - 1
tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv

@@ -86,7 +86,6 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 // OpName %BEZIER_CONTROL_POINT "BEZIER_CONTROL_POINT"
 // OpMemberName %BEZIER_CONTROL_POINT 0 "vPosition"
 // OpName %out_var_BEZIERPOS "out.var.BEZIERPOS"
-// OpName %SubDToBezierConstantsHS "SubDToBezierConstantsHS"
 // OpName %HS_CONSTANT_DATA_OUTPUT "HS_CONSTANT_DATA_OUTPUT"
 // OpMemberName %HS_CONSTANT_DATA_OUTPUT 0 "Edges"
 // OpMemberName %HS_CONSTANT_DATA_OUTPUT 1 "Inside"
@@ -95,6 +94,7 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 // OpMemberName %HS_CONSTANT_DATA_OUTPUT 4 "vTanUCorner"
 // OpMemberName %HS_CONSTANT_DATA_OUTPUT 5 "vTanVCorner"
 // OpMemberName %HS_CONSTANT_DATA_OUTPUT 6 "vCWts"
+// OpName %SubDToBezierConstantsHS "SubDToBezierConstantsHS"
 // OpName %out_var_TANGENT "out.var.TANGENT"
 // OpName %out_var_TEXCOORD "out.var.TEXCOORD"
 // OpName %out_var_TANUCORNER "out.var.TANUCORNER"

+ 168 - 0
tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.counter.hlsl

@@ -0,0 +1,168 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct S1 {
+    float4 f;
+};
+
+struct S2 {
+    float3 f;
+};
+
+struct S3 {
+    float2 f;
+};
+
+RWStructuredBuffer<S1>      selectRWSBuffer(RWStructuredBuffer<S1>    paramRWSBuffer, bool selector);
+AppendStructuredBuffer<S2>  selectASBuffer(AppendStructuredBuffer<S2>  paramASBuffer,  bool selector);
+ConsumeStructuredBuffer<S3> selectCSBuffer(ConsumeStructuredBuffer<S3> paramCSBuffer,  bool selector);
+
+// CHECK: %counter_var_globalRWSBuffer = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
+RWStructuredBuffer<S1>      globalRWSBuffer;
+// CHECK: %counter_var_globalASBuffer = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
+AppendStructuredBuffer<S2>  globalASBuffer;
+// CHECK: %counter_var_globalCSBuffer = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
+ConsumeStructuredBuffer<S3> globalCSBuffer;
+
+// Counter variables for global static variables have an extra level of pointer.
+// CHECK: %counter_var_staticgRWSBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+static RWStructuredBuffer<S1>      staticgRWSBuffer = globalRWSBuffer;
+// CHECK: %counter_var_staticgASBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+static AppendStructuredBuffer<S2>  staticgASBuffer  = globalASBuffer;
+// CHECK: %counter_var_staticgCSBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+static ConsumeStructuredBuffer<S3> staticgCSBuffer  = globalCSBuffer;
+
+// Counter variables for function returns, function parameters, and local variables have an extra level of pointer.
+// CHECK:      %counter_var_paramRWSBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK-NEXT: %counter_var_selectRWSBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK-NEXT: %counter_var_localRWSBufferMain = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK-NEXT: %counter_var_paramCSBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK-NEXT: %counter_var_selectCSBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK-NEXT: %counter_var_localASBufferMain = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK-NEXT: %counter_var_paramASBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK-NEXT: %counter_var_selectASBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK-NEXT: %counter_var_localRWSBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK-NEXT: %counter_var_localCSBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK-NEXT: %counter_var_localASBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// Counter variables for global static variables are initialized.
+// CHECK: %main = OpFunction
+// CHECK: OpStore %counter_var_staticgRWSBuffer %counter_var_globalRWSBuffer
+// CHECK: OpStore %counter_var_staticgASBuffer %counter_var_globalASBuffer
+// CHECK: OpStore %counter_var_staticgCSBuffer %counter_var_globalCSBuffer
+
+// CHECK: %src_main = OpFunction
+float4 main() : SV_Target {
+// Update the counter variable associated with the parameter
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticgRWSBuffer
+// CHECK-NEXT:                OpStore %counter_var_paramRWSBuffer [[ptr]]
+    selectRWSBuffer(staticgRWSBuffer, true)
+// Use the counter variable associated with the function
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_selectRWSBuffer
+// CHECK-NEXT:     {{%\d+}} = OpAccessChain %_ptr_Uniform_int [[ptr]] %uint_0
+        .IncrementCounter();
+
+// Update the counter variable associated with the parameter
+// CHECK:                     OpStore %counter_var_paramRWSBuffer %counter_var_globalRWSBuffer
+// Update the counter variable associated with the lhs of the assignment
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_selectRWSBuffer
+// CHECK-NEXT:                OpStore %counter_var_localRWSBufferMain [[ptr]]
+    RWStructuredBuffer<S1> localRWSBufferMain = selectRWSBuffer(globalRWSBuffer, true);
+
+// Use the counter variable associated with the local variable
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localRWSBufferMain
+// CHECK-NEXT:     {{%\d+}} = OpAccessChain %_ptr_Uniform_int [[ptr]] %uint_0
+    localRWSBufferMain.DecrementCounter();
+
+// Update the counter variable associated with the parameter
+// CHECK:                      OpStore %counter_var_paramCSBuffer %counter_var_globalCSBuffer
+// CHECK:      [[call:%\d+]] = OpFunctionCall %_ptr_Uniform_type_ConsumeStructuredBuffer_S3 %selectCSBuffer
+    S3 val3 = selectCSBuffer(globalCSBuffer, true)
+// CHECK-NEXT: [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_selectCSBuffer
+// CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_int [[ptr1]] %uint_0
+// CHECK-NEXT: [[prev:%\d+]] = OpAtomicISub %int [[ptr2]] %uint_1 %uint_0 %int_1
+// CHECK-NEXT:  [[idx:%\d+]] = OpISub %int [[prev]] %int_1
+// CHECK-NEXT: [[ptr3:%\d+]] = OpAccessChain %_ptr_Uniform_S3 [[call]] %uint_0 [[idx]]
+// CHECK-NEXT:  [[val:%\d+]] = OpLoad %S3 [[ptr3]]
+// CHECK-NEXT:  [[vec:%\d+]] = OpCompositeExtract %v2float [[val]] 0
+// CHECK-NEXT: [[ptr4:%\d+]] = OpAccessChain %_ptr_Function_v2float %val3 %uint_0
+// CHECK-NEXT:                 OpStore [[ptr4]] [[vec]]
+        .Consume();
+
+    float3 vec = float3(val3.f, 1.0);
+    S2 val2 = {vec};
+
+// Update the counter variable associated with the parameter
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticgASBuffer
+// CHECK-NEXT:                OpStore %counter_var_paramASBuffer [[ptr]]
+
+// CHECK:     [[call:%\d+]] = OpFunctionCall %_ptr_Uniform_type_AppendStructuredBuffer_S2 %selectASBuffer %param_var_paramASBuffer %param_var_selector_2
+// CHECK-NEXT:                OpStore %localASBufferMain [[call]]
+    AppendStructuredBuffer<S2> localASBufferMain = selectASBuffer(staticgASBuffer, false);
+// Use the counter variable associated with the local variable
+// CHECK-NEXT: [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_selectASBuffer
+// CHECK-NEXT:                OpStore %counter_var_localASBufferMain [[ptr]]
+
+// CHECK-NEXT: [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_AppendStructuredBuffer_S2 %localASBufferMain
+// CHECK-NEXT: [[ptr2:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localASBufferMain
+// CHECK-NEXT: [[ptr3:%\d+]] = OpAccessChain %_ptr_Uniform_int [[ptr2]] %uint_0
+// CHECK-NEXT:  [[idx:%\d+]] = OpAtomicIAdd %int [[ptr3]] %uint_1 %uint_0 %int_1
+// CHECK-NEXT: [[ptr4:%\d+]] = OpAccessChain %_ptr_Uniform_S2 [[ptr1]] %uint_0 [[idx]]
+// CHECK-NEXT:  [[val:%\d+]] = OpLoad %S2_0 %val2
+// CHECK-NEXT:  [[vec:%\d+]] = OpCompositeExtract %v3float [[val]] 0
+// CHECK-NEXT: [[ptr5:%\d+]] = OpAccessChain %_ptr_Uniform_v3float [[ptr4]] %uint_0
+// CHECK-NEXT:                 OpStore [[ptr5]] [[vec]]
+    localASBufferMain.Append(val2);
+
+    return float4(val2, 2.0);
+}
+
+RWStructuredBuffer<S1>      selectRWSBuffer(RWStructuredBuffer<S1>    paramRWSBuffer, bool selector) {
+// CHECK: OpStore %counter_var_localRWSBuffer %counter_var_globalRWSBuffer
+    RWStructuredBuffer<S1>      localRWSBuffer = globalRWSBuffer;
+    if (selector)
+// Use the counter variable associated with the function
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_paramRWSBuffer
+// CHECK-NEXT:                OpStore %counter_var_selectRWSBuffer [[ptr]]
+// CHECK:                     OpReturnValue
+        return paramRWSBuffer;
+    else
+// Use the counter variable associated with the function
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localRWSBuffer
+// CHECK-NEXT:                OpStore %counter_var_selectRWSBuffer [[ptr]]
+// CHECK:                     OpReturnValue
+        return localRWSBuffer;
+}
+
+ConsumeStructuredBuffer<S3> selectCSBuffer(ConsumeStructuredBuffer<S3> paramCSBuffer,  bool selector) {
+// CHECK: OpStore %counter_var_localCSBuffer %counter_var_globalCSBuffer
+    ConsumeStructuredBuffer<S3> localCSBuffer  = globalCSBuffer;
+    if (selector)
+// Use the counter variable associated with the function
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_paramCSBuffer
+// CHECK-NEXT:                OpStore %counter_var_selectCSBuffer [[ptr]]
+// CHECK:                     OpReturnValue
+        return paramCSBuffer;
+    else
+// Use the counter variable associated with the function
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localCSBuffer
+// CHECK-NEXT:                OpStore %counter_var_selectCSBuffer [[ptr]]
+// CHECK:                     OpReturnValue
+        return localCSBuffer;
+}
+
+AppendStructuredBuffer<S2>  selectASBuffer(AppendStructuredBuffer<S2>  paramASBuffer,  bool selector) {
+// CHECK: OpStore %counter_var_localASBuffer %counter_var_globalASBuffer
+    AppendStructuredBuffer<S2>  localASBuffer  = globalASBuffer;
+    if (selector)
+// Use the counter variable associated with the function
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_paramASBuffer
+// CHECK-NEXT:                OpStore %counter_var_selectASBuffer [[ptr]]
+// CHECK:                     OpReturnValue
+        return paramASBuffer;
+    else
+// Use the counter variable associated with the function
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localASBuffer
+// CHECK-NEXT:                OpStore %counter_var_selectASBuffer [[ptr]]
+// CHECK:                     OpReturnValue
+        return localASBuffer;
+}

+ 0 - 10
tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.methods.hlsl

@@ -68,24 +68,14 @@ float4 main() : SV_Target {
 // CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_v4float [[ptr1]] %int_0 %uint_4 %int_0
 // CHECK-NEXT:                 OpStore [[ptr2]] {{%\d+}}
     localRWSBuffer[4].f = 42.;
-    // TODO
-    counter = localRWSBuffer.IncrementCounter();
-    // TODO
-    counter = localRWSBuffer.DecrementCounter();
 
 // CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_AppendStructuredBuffer_T1 %localASBuffer
 // CHECK-NEXT:     {{%\d+}} = OpArrayLength %uint [[ptr]] 0
     localASBuffer.GetDimensions(numStructs, stride);
-    // TODO
-    counter = localRWSBuffer.DecrementCounter();
-    localASBuffer.Append(t1);
 
 // CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ConsumeStructuredBuffer_T2 %localCSBuffer
 // CHECK-NEXT:     {{%\d+}} = OpArrayLength %uint [[ptr]] 0
     localCSBuffer.GetDimensions(numStructs, stride);
-    // TODO
-    counter = localRWSBuffer.DecrementCounter();
-    t2 = localCSBuffer.Consume();
 
     uint  byte;
     uint2 byte2;

+ 5 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -978,6 +978,11 @@ TEST_F(FileTest, SpirvLegalizationStructuredBufferMethods) {
               // The generated SPIR-V needs legalization.
               /*runValidation=*/false);
 }
+TEST_F(FileTest, SpirvLegalizationStructuredBufferCounter) {
+  runFileTest("spirv.legal.sbuffer.counter.hlsl", Expect::Success,
+              // The generated SPIR-V needs legalization.
+              /*runValidation=*/false);
+}
 TEST_F(FileTest, SpirvLegalizationConstantBuffer) {
   runFileTest("spirv.legal.cbuffer.hlsl", Expect::Success,
               // The generated SPIR-V needs legalization.