Bladeren bron

[spirv] Legalization: associated counters in implicit objects (#1002)

An implicit object is translated into the first argument to the
method call. If the implicit object is of struct type and its
fields have associated counters, we need to assign its associated
counters accordingly like other normal arguments.

For each such method, we generate associated counters for its
implicit object. At a call site, we assign the associated counters
from the real object to the ones associated with the method implict
object.

Also refreshed SPIRV-Tools
Lei Zhang 7 jaren geleden
bovenliggende
commit
4b2cbf2167

+ 1 - 1
external/SPIRV-Tools

@@ -1 +1 @@
-Subproject commit d54a286c754230a67ffdb23850253596800ed0bd
+Subproject commit 6cc772c3ce358193aac132a2c798e79e21aec0ad

+ 6 - 3
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -595,8 +595,10 @@ uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
 }
 }
 
 
 const CounterIdAliasPair *DeclResultIdMapper::getCounterIdAliasPair(
 const CounterIdAliasPair *DeclResultIdMapper::getCounterIdAliasPair(
-    const DeclaratorDecl *decl,
-    const llvm::SmallVector<uint32_t, 4> *indices) const {
+    const DeclaratorDecl *decl, const llvm::SmallVector<uint32_t, 4> *indices) {
+  if (!decl)
+    return nullptr;
+
   if (indices) {
   if (indices) {
     // Indices are provided. Walk through the fields of the decl.
     // Indices are provided. Walk through the fields of the decl.
     const auto counter = fieldCounterVars.find(decl);
     const auto counter = fieldCounterVars.find(decl);
@@ -608,11 +610,12 @@ const CounterIdAliasPair *DeclResultIdMapper::getCounterIdAliasPair(
     if (counter != counterVars.end())
     if (counter != counterVars.end())
       return &counter->second;
       return &counter->second;
   }
   }
+
   return nullptr;
   return nullptr;
 }
 }
 
 
 const CounterVarFields *
 const CounterVarFields *
-DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) const {
+DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
   if (!decl)
   if (!decl)
     return nullptr;
     return nullptr;
 
 

+ 7 - 6
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -283,10 +283,11 @@ public:
   uint32_t createFnParam(const ParmVarDecl *param);
   uint32_t createFnParam(const ParmVarDecl *param);
 
 
   /// \brief Creates the counter variable associated with the given param.
   /// \brief Creates the counter variable associated with the given param.
-  /// This is meant to be used for forward-declared functions.
+  /// This is meant to be used for forward-declared functions and this objects
+  /// of methods.
   ///
   ///
   /// Note: legalization specific code
   /// Note: legalization specific code
-  inline void createFnParamCounterVar(const ParmVarDecl *param);
+  inline void createFnParamCounterVar(const VarDecl *param);
 
 
   /// \brief Creates a function-scope variable in the current function and
   /// \brief Creates a function-scope variable in the current function and
   /// returns its <result-id>.
   /// returns its <result-id>.
@@ -383,12 +384,12 @@ public:
   /// if the given decl has no associated counter variable created.
   /// if the given decl has no associated counter variable created.
   const CounterIdAliasPair *getCounterIdAliasPair(
   const CounterIdAliasPair *getCounterIdAliasPair(
       const DeclaratorDecl *decl,
       const DeclaratorDecl *decl,
-      const llvm::SmallVector<uint32_t, 4> *indices = nullptr) const;
+      const llvm::SmallVector<uint32_t, 4> *indices = nullptr);
 
 
   /// \brief Returns all the associated counters for the given decl. The decl is
   /// \brief Returns all the associated counters for the given decl. The decl is
   /// expected to be a struct containing alias RW/Append/Consume structured
   /// expected to be a struct containing alias RW/Append/Consume structured
   /// buffers. Returns nullptr if it does not.
   /// buffers. Returns nullptr if it does not.
-  const CounterVarFields *getCounterVarFields(const DeclaratorDecl *decl) const;
+  const CounterVarFields *getCounterVarFields(const DeclaratorDecl *decl);
 
 
   /// \brief Returns the <type-id> for the given cbuffer, tbuffer,
   /// \brief Returns the <type-id> for the given cbuffer, tbuffer,
   /// ConstantBuffer, TextureBuffer, or push constant block.
   /// ConstantBuffer, TextureBuffer, or push constant block.
@@ -717,8 +718,8 @@ bool DeclResultIdMapper::isInputStorageClass(const StageVar &v) {
          spv::StorageClass::Input;
          spv::StorageClass::Input;
 }
 }
 
 
-void DeclResultIdMapper::createFnParamCounterVar(const ParmVarDecl *param) {
-  return createCounterVarForDecl(param);
+void DeclResultIdMapper::createFnParamCounterVar(const VarDecl *param) {
+  createCounterVarForDecl(param);
 }
 }
 
 
 void DeclResultIdMapper::createFieldCounterVars(const DeclaratorDecl *decl) {
 void DeclResultIdMapper::createFieldCounterVars(const DeclaratorDecl *decl) {

+ 50 - 15
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1652,14 +1652,18 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
   llvm::SmallVector<SpirvEvalInfo, 4> args; // Evaluated arguments
   llvm::SmallVector<SpirvEvalInfo, 4> args; // Evaluated arguments
 
 
   if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr)) {
   if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr)) {
-    isNonStaticMemberCall =
-        !cast<CXXMethodDecl>(memberCall->getCalleeDecl())->isStatic();
+    const auto *memberFn = cast<CXXMethodDecl>(memberCall->getCalleeDecl());
+    isNonStaticMemberCall = !memberFn->isStatic();
+
     if (isNonStaticMemberCall) {
     if (isNonStaticMemberCall) {
       // For non-static member calls, evaluate the object and pass it as the
       // For non-static member calls, evaluate the object and pass it as the
       // first argument.
       // first argument.
       const auto *object = memberCall->getImplicitObjectArgument();
       const auto *object = memberCall->getImplicitObjectArgument();
       object = object->IgnoreParenNoopCasts(astContext);
       object = object->IgnoreParenNoopCasts(astContext);
 
 
+      // Update counter variable associated with the implicit object
+      tryToAssignCounterVar(getOrCreateDeclForMethodObject(memberFn), object);
+
       objectType = object->getType();
       objectType = object->getType();
       objectEvalInfo = doExpr(object);
       objectEvalInfo = doExpr(object);
       uint32_t objectId = objectEvalInfo;
       uint32_t objectId = objectEvalInfo;
@@ -2817,6 +2821,9 @@ bool SPIRVEmitter::tryToAssignCounterVar(const DeclaratorDecl *dstDecl,
   // the translation of the real definition may not be started yet.
   // the translation of the real definition may not be started yet.
   if (const auto *param = dyn_cast<ParmVarDecl>(dstDecl))
   if (const auto *param = dyn_cast<ParmVarDecl>(dstDecl))
     declIdMapper.createFnParamCounterVar(param);
     declIdMapper.createFnParamCounterVar(param);
+  // For implicit objects of methods. Similar to the above.
+  else if (const auto *thisObject = dyn_cast<ImplicitParamDecl>(dstDecl))
+    declIdMapper.createFnParamCounterVar(thisObject);
 
 
   // Handle AssocCounter#1 (see CounterVarFields comment)
   // Handle AssocCounter#1 (see CounterVarFields comment)
   if (const auto *dstPair = declIdMapper.getCounterIdAliasPair(dstDecl)) {
   if (const auto *dstPair = declIdMapper.getCounterIdAliasPair(dstDecl)) {
@@ -2831,11 +2838,9 @@ bool SPIRVEmitter::tryToAssignCounterVar(const DeclaratorDecl *dstDecl,
   }
   }
 
 
   // Handle AssocCounter#3
   // Handle AssocCounter#3
-  const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl);
   llvm::SmallVector<uint32_t, 4> srcIndices;
   llvm::SmallVector<uint32_t, 4> srcIndices;
-  const auto *srcDecl = getReferencedDef(
-      collectArrayStructIndices(srcExpr, &srcIndices, /*rawIndex=*/true));
-  const auto *srcFields = declIdMapper.getCounterVarFields(srcDecl);
+  const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl);
+  const auto *srcFields = getIntermediateACSBufferCounter(srcExpr, &srcIndices);
 
 
   if (dstFields && srcFields) {
   if (dstFields && srcFields) {
     if (!dstFields->assign(*srcFields, theBuilder, typeTranslator)) {
     if (!dstFields->assign(*srcFields, theBuilder, typeTranslator)) {
@@ -2875,12 +2880,8 @@ bool SPIRVEmitter::tryToAssignCounterVar(const Expr *dstExpr,
   // Handle AssocCounter#3 & AssocCounter#4
   // Handle AssocCounter#3 & AssocCounter#4
   llvm::SmallVector<uint32_t, 4> dstIndices;
   llvm::SmallVector<uint32_t, 4> dstIndices;
   llvm::SmallVector<uint32_t, 4> srcIndices;
   llvm::SmallVector<uint32_t, 4> srcIndices;
-  const auto *dstDecl = getReferencedDef(
-      collectArrayStructIndices(dstExpr, &dstIndices, /*rawIndex=*/true));
-  const auto *srcDecl = getReferencedDef(
-      collectArrayStructIndices(srcExpr, &srcIndices, /*rawIndex=*/true));
-  const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl);
-  const auto *srcFields = declIdMapper.getCounterVarFields(srcDecl);
+  const auto *srcFields = getIntermediateACSBufferCounter(srcExpr, &srcIndices);
+  const auto *dstFields = getIntermediateACSBufferCounter(dstExpr, &dstIndices);
 
 
   if (dstFields && srcFields) {
   if (dstFields && srcFields) {
     return dstFields->assign(*srcFields, dstIndices, srcIndices, theBuilder,
     return dstFields->assign(*srcFields, dstIndices, srcIndices, theBuilder,
@@ -2898,13 +2899,47 @@ SPIRVEmitter::getFinalACSBufferCounter(const Expr *expr) {
 
 
   // AssocCounter#2: referencing some non-struct field
   // AssocCounter#2: referencing some non-struct field
   llvm::SmallVector<uint32_t, 4> indices;
   llvm::SmallVector<uint32_t, 4> indices;
-  if (const auto *decl = getReferencedDef(
-          collectArrayStructIndices(expr, &indices, /*rawIndex=*/true)))
-    return declIdMapper.getCounterIdAliasPair(decl, &indices);
+
+  const auto *base =
+      collectArrayStructIndices(expr, &indices, /*rawIndex=*/true);
+  const auto *decl =
+      (base && isa<CXXThisExpr>(base))
+          ? getOrCreateDeclForMethodObject(cast<CXXMethodDecl>(curFunction))
+          : getReferencedDef(base);
+  return declIdMapper.getCounterIdAliasPair(decl, &indices);
 
 
   return nullptr;
   return nullptr;
 }
 }
 
 
+const CounterVarFields *SPIRVEmitter::getIntermediateACSBufferCounter(
+    const Expr *expr, llvm::SmallVector<uint32_t, 4> *indices) {
+  const auto *base =
+      collectArrayStructIndices(expr, indices, /*rawIndex=*/true);
+  const auto *decl =
+      (base && isa<CXXThisExpr>(base))
+          // Use the decl we created to represent the implicit object
+          ? getOrCreateDeclForMethodObject(cast<CXXMethodDecl>(curFunction))
+          // Find the referenced decl from the original source code
+          : getReferencedDef(base);
+
+  return declIdMapper.getCounterVarFields(decl);
+}
+
+const ImplicitParamDecl *
+SPIRVEmitter::getOrCreateDeclForMethodObject(const CXXMethodDecl *method) {
+  const auto found = thisDecls.find(method);
+  if (found != thisDecls.end())
+    return found->second;
+
+  const std::string name = method->getName().str() + ".this";
+  // Create a new identifier to convey the name
+  auto &identifier = astContext.Idents.get(name);
+
+  return thisDecls[method] = ImplicitParamDecl::Create(
+             astContext, /*DC=*/nullptr, SourceLocation(), &identifier,
+             method->getThisType(astContext)->getPointeeType());
+}
+
 SpirvEvalInfo
 SpirvEvalInfo
 SPIRVEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) {
 SPIRVEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) {
   const bool isAppend = expr->getNumArgs() == 1;
   const bool isAppend = expr->getNumArgs() == 1;

+ 21 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -698,7 +698,16 @@ private:
   ///
   ///
   /// This method only handles final alias structured buffers, which means
   /// This method only handles final alias structured buffers, which means
   /// AssocCounter#1 and AssocCounter#2.
   /// AssocCounter#1 and AssocCounter#2.
-  const CounterIdAliasPair *getFinalACSBufferCounter(const Expr *decl);
+  const CounterIdAliasPair *getFinalACSBufferCounter(const Expr *expr);
+  /// This method handles AssocCounter#3 and AssocCounter#4.
+  const CounterVarFields *
+  getIntermediateACSBufferCounter(const Expr *expr,
+                                  llvm::SmallVector<uint32_t, 4> *indices);
+
+  /// Gets or creates an ImplicitParamDecl to represent the implicit object
+  /// parameter of the given method.
+  const ImplicitParamDecl *
+  getOrCreateDeclForMethodObject(const CXXMethodDecl *method);
 
 
   /// \brief Loads numWords 32-bit unsigned integers or stores numWords 32-bit
   /// \brief Loads numWords 32-bit unsigned integers or stores numWords 32-bit
   /// unsigned integers (based on the doStore parameter) to the given
   /// unsigned integers (based on the doStore parameter) to the given
@@ -843,6 +852,17 @@ private:
   /// Note: legalization specific code
   /// Note: legalization specific code
   bool needsLegalization;
   bool needsLegalization;
 
 
+  /// Mapping from methods to the decls to represent their implicit object
+  /// parameters
+  ///
+  /// We need this map because that we need to update the associated counters on
+  /// the implicit object when invoking method calls. The ImplicitParamDecl
+  /// mapped to serves as a key to find the associated counters in
+  /// DeclResultIdMapper.
+  ///
+  /// Note: legalization specific code
+  llvm::DenseMap<const CXXMethodDecl *, const ImplicitParamDecl *> thisDecls;
+
   /// Global variables that should be initialized once at the begining of the
   /// Global variables that should be initialized once at the begining of the
   /// entry function.
   /// entry function.
   llvm::SmallVector<const VarDecl *, 4> toInitGloalVars;
   llvm::SmallVector<const VarDecl *, 4> toInitGloalVars;

+ 114 - 0
tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.counter.method.hlsl

@@ -0,0 +1,114 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct Bundle {
+      RWStructuredBuffer<float> rw;
+  AppendStructuredBuffer<float> append;
+ ConsumeStructuredBuffer<float> consume;
+};
+
+// Counter variables for the this object of getCSBuffer()
+// CHECK: %counter_var_getCSBuffer_this_0_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getCSBuffer_this_0_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getCSBuffer_this_0_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getCSBuffer_this_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getCSBuffer_this_1_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getCSBuffer_this_1_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// Counter variables for the this object of getASBuffer()
+// CHECK: %counter_var_getASBuffer_this_0_0_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getASBuffer_this_0_0_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getASBuffer_this_0_0_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getASBuffer_this_0_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getASBuffer_this_0_1_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getASBuffer_this_0_1_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// Counter variables for the this object of getRWSBuffer()
+// CHECK: %counter_var_getRWSBuffer_this_0_0_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getRWSBuffer_this_0_0_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getRWSBuffer_this_0_0_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getRWSBuffer_this_0_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getRWSBuffer_this_0_1_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getRWSBuffer_this_0_1_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+struct TwoBundle {
+    Bundle b1;
+    Bundle b2;
+
+    // Checks at the end of the file
+    ConsumeStructuredBuffer<float> getCSBuffer() { return b1.consume; }
+};
+
+struct Wrapper {
+    TwoBundle b;
+
+    // Checks at the end of the file
+    AppendStructuredBuffer<float> getASBuffer() { return b.b1.append; }
+
+    // Checks at the end of the file
+    RWStructuredBuffer<float> getRWSBuffer() { return b.b2.rw; }
+};
+
+// CHECK-LABLE: %src_main = OpFunction
+float main() : VVV {
+    TwoBundle localBundle;
+    Wrapper   localWrapper;
+
+// CHECK:      [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localBundle_0_0
+// CHECK-NEXT:                    OpStore %counter_var_getCSBuffer_this_0_0 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localBundle_0_1
+// CHECK-NEXT:                    OpStore %counter_var_getCSBuffer_this_0_1 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localBundle_0_2
+// CHECK-NEXT:                    OpStore %counter_var_getCSBuffer_this_0_2 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localBundle_1_0
+// CHECK-NEXT:                    OpStore %counter_var_getCSBuffer_this_1_0 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localBundle_1_1
+// CHECK-NEXT:                    OpStore %counter_var_getCSBuffer_this_1_1 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localBundle_1_2
+// CHECK-NEXT:                    OpStore %counter_var_getCSBuffer_this_1_2 [[counter]]
+// CHECK-NEXT:                    OpFunctionCall %_ptr_Uniform_type_RWStructuredBuffer_float %TwoBundle_getCSBuffer %localBundle
+    float value = localBundle.getCSBuffer().Consume();
+
+// CHECK:      [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_0_0
+// CHECK-NEXT:                    OpStore %counter_var_getASBuffer_this_0_0_0 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_0_1
+// CHECK-NEXT:                    OpStore %counter_var_getASBuffer_this_0_0_1 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_0_2
+// CHECK-NEXT:                    OpStore %counter_var_getASBuffer_this_0_0_2 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_1_0
+// CHECK-NEXT:                    OpStore %counter_var_getASBuffer_this_0_1_0 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_1_1
+// CHECK-NEXT:                    OpStore %counter_var_getASBuffer_this_0_1_1 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_1_2
+// CHECK-NEXT:                    OpStore %counter_var_getASBuffer_this_0_1_2 [[counter]]
+// CHECK-NEXT:                    OpFunctionCall %_ptr_Uniform_type_RWStructuredBuffer_float %Wrapper_getASBuffer %localWrapper
+    localWrapper.getASBuffer().Append(4.2);
+
+// CHECK:      [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_0_0
+// CHECK-NEXT:                    OpStore %counter_var_getRWSBuffer_this_0_0_0 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_0_1
+// CHECK-NEXT:                    OpStore %counter_var_getRWSBuffer_this_0_0_1 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_0_2
+// CHECK-NEXT:                    OpStore %counter_var_getRWSBuffer_this_0_0_2 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_1_0
+// CHECK-NEXT:                    OpStore %counter_var_getRWSBuffer_this_0_1_0 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_1_1
+// CHECK-NEXT:                    OpStore %counter_var_getRWSBuffer_this_0_1_1 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_1_2
+// CHECK-NEXT:                    OpStore %counter_var_getRWSBuffer_this_0_1_2 [[counter]]
+// CHECK-NEXT:                    OpFunctionCall %_ptr_Uniform_type_RWStructuredBuffer_float %Wrapper_getRWSBuffer %localWrapper
+    RWStructuredBuffer<float> localRWSBuffer = localWrapper.getRWSBuffer();
+
+    return localRWSBuffer[5];
+}
+
+// CHECK-LABEL: %TwoBundle_getCSBuffer = OpFunction
+// CHECK:             [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_getCSBuffer_this_0_2
+// CHECK-NEXT:                           OpStore %counter_var_getCSBuffer [[counter]]
+
+// CHECK-LABEL: %Wrapper_getASBuffer = OpFunction
+// CHECK:           [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_getASBuffer_this_0_0_1
+// CHECK-NEXT:                         OpStore %counter_var_getASBuffer [[counter]]
+
+// CHECK-LABEL: %Wrapper_getRWSBuffer = OpFunction
+// CHECK:            [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_getRWSBuffer_this_0_1_0
+// CHECK-NEXT:                          OpStore %counter_var_getRWSBuffer [[counter]]

+ 18 - 12
tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.struct.hlsl

@@ -11,24 +11,30 @@ struct S {
     ConsumeStructuredBuffer<float4> consume;
     ConsumeStructuredBuffer<float4> consume;
 };
 };
 
 
-// CHECK: %T = OpTypeStruct %_ptr_Uniform_type_StructuredBuffer_Basic %_ptr_Uniform_type_RWStructuredBuffer_Basic %_ptr_Uniform_type_StructuredBuffer_int
+// CHECK: %T = OpTypeStruct %_ptr_Uniform_type_StructuredBuffer_Basic %_ptr_Uniform_type_RWStructuredBuffer_Basic
 struct T {
 struct T {
       StructuredBuffer<Basic> ro;
       StructuredBuffer<Basic> ro;
     RWStructuredBuffer<Basic> rw;
     RWStructuredBuffer<Basic> rw;
-      StructuredBuffer<int>   ro2;
 
 
-    StructuredBuffer<Basic> getSBuffer() { return ro; }
-    StructuredBuffer<int>   getSBuffer2() { return ro2; }
 };
 };
 
 
-// CHECK: %Combine = OpTypeStruct %S %T %_ptr_Uniform_type_ByteAddressBuffer %_ptr_Uniform_type_RWByteAddressBuffer
+struct U {
+    StructuredBuffer<Basic> basic;
+    StructuredBuffer<int>   integer;
+
+    StructuredBuffer<Basic> getSBufferStruct() { return basic;   }
+    StructuredBuffer<int>   getSBufferInt()    { return integer; }
+};
+
+// CHECK: %Combine = OpTypeStruct %S %T %_ptr_Uniform_type_ByteAddressBuffer %_ptr_Uniform_type_RWByteAddressBuffer %U
 struct Combine {
 struct Combine {
                       S s;
                       S s;
                       T t;
                       T t;
       ByteAddressBuffer ro;
       ByteAddressBuffer ro;
     RWByteAddressBuffer rw;
     RWByteAddressBuffer rw;
+                      U u;
 
 
-    T getT() { return t; }
+    U getU() { return u; }
 };
 };
 
 
        StructuredBuffer<Basic>  gSBuffer;
        StructuredBuffer<Basic>  gSBuffer;
@@ -71,17 +77,17 @@ float4 main() : SV_Target {
 
 
 // Make sure that we create temporary variable for intermediate objects since
 // Make sure that we create temporary variable for intermediate objects since
 // the function expect pointers as parameters.
 // the function expect pointers as parameters.
-// CHECK:     [[call1:%\d+]] = OpFunctionCall %T %Combine_getT %c
-// CHECK-NEXT:                 OpStore %temp_var_T [[call1]]
-// CHECK-NEXT:[[call2:%\d+]] = OpFunctionCall %_ptr_Uniform_type_StructuredBuffer_Basic %T_getSBuffer %temp_var_T
+// CHECK:     [[call1:%\d+]] = OpFunctionCall %U %Combine_getU %c
+// CHECK-NEXT:                 OpStore %temp_var_U [[call1]]
+// CHECK-NEXT:[[call2:%\d+]] = OpFunctionCall %_ptr_Uniform_type_StructuredBuffer_Basic %U_getSBufferStruct %temp_var_U
 // CHECK-NEXT:  [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_v4float [[call2]] %int_0 %uint_10 %int_1
 // CHECK-NEXT:  [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_v4float [[call2]] %int_0 %uint_10 %int_1
 // CHECK-NEXT:      {{%\d+}} = OpLoad %v4float [[ptr]]
 // CHECK-NEXT:      {{%\d+}} = OpLoad %v4float [[ptr]]
-    float4 val = c.getT().getSBuffer()[10].b;
+    float4 val = c.getU().getSBufferStruct()[10].b;
 
 
 // Check StructuredBuffer of scalar type
 // Check StructuredBuffer of scalar type
-// CHECK:     [[call2:%\d+]] = OpFunctionCall %_ptr_Uniform_type_StructuredBuffer_int %T_getSBuffer2 %temp_var_T_0
+// CHECK:     [[call2:%\d+]] = OpFunctionCall %_ptr_Uniform_type_StructuredBuffer_int %U_getSBufferInt %temp_var_U_0
 // CHECK-NEXT:      {{%\d+}} = OpAccessChain %_ptr_Uniform_int [[call2]] %int_0 %uint_42
 // CHECK-NEXT:      {{%\d+}} = OpAccessChain %_ptr_Uniform_int [[call2]] %int_0 %uint_42
-    int index = c.getT().getSBuffer2()[42];
+    int index = c.getU().getSBufferInt()[42];
 
 
 // CHECK:      [[val:%\d+]] = OpLoad %Combine %c
 // CHECK:      [[val:%\d+]] = OpLoad %Combine %c
 // CHECK:                     OpStore %param_var_comb [[val]]
 // CHECK:                     OpStore %param_var_comb [[val]]

+ 6 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -1006,9 +1006,15 @@ TEST_F(FileTest, SpirvLegalizationStructuredBufferCounter) {
               /*runValidation=*/true, /*relaxLogicalPointer=*/true);
               /*runValidation=*/true, /*relaxLogicalPointer=*/true);
 }
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferCounterInStruct) {
 TEST_F(FileTest, SpirvLegalizationStructuredBufferCounterInStruct) {
+  // Tests using struct/class having associated counters
   runFileTest("spirv.legal.sbuffer.counter.struct.hlsl", Expect::Success,
   runFileTest("spirv.legal.sbuffer.counter.struct.hlsl", Expect::Success,
               /*runValidation=*/true, /*relaxLogicalPointer=*/true);
               /*runValidation=*/true, /*relaxLogicalPointer=*/true);
 }
 }
+TEST_F(FileTest, SpirvLegalizationStructuredBufferCounterInMethod) {
+  // Tests using methods whose enclosing struct/class having associated counters
+  runFileTest("spirv.legal.sbuffer.counter.method.hlsl", Expect::Success,
+              /*runValidation=*/true, /*relaxLogicalPointer=*/true);
+}
 TEST_F(FileTest, SpirvLegalizationStructuredBufferInStruct) {
 TEST_F(FileTest, SpirvLegalizationStructuredBufferInStruct) {
   runFileTest("spirv.legal.sbuffer.struct.hlsl", Expect::Success,
   runFileTest("spirv.legal.sbuffer.struct.hlsl", Expect::Success,
               /*runValidation=*/true, /*relaxLogicalPointer=*/true);
               /*runValidation=*/true, /*relaxLogicalPointer=*/true);