소스 검색

[spirv] Legalization: associated counters in implicit objects (#1002)

An implicit object is translated into the first argument to the
method call. If the implicit object is of struct type and its
fields have associated counters, we need to assign its associated
counters accordingly like other normal arguments.

For each such method, we generate associated counters for its
implicit object. At a call site, we assign the associated counters
from the real object to the ones associated with the method implict
object.

Also refreshed SPIRV-Tools
Lei Zhang 7 년 전
부모
커밋
4b2cbf2167

+ 1 - 1
external/SPIRV-Tools

@@ -1 +1 @@
-Subproject commit d54a286c754230a67ffdb23850253596800ed0bd
+Subproject commit 6cc772c3ce358193aac132a2c798e79e21aec0ad

+ 6 - 3
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -595,8 +595,10 @@ uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
 }
 
 const CounterIdAliasPair *DeclResultIdMapper::getCounterIdAliasPair(
-    const DeclaratorDecl *decl,
-    const llvm::SmallVector<uint32_t, 4> *indices) const {
+    const DeclaratorDecl *decl, const llvm::SmallVector<uint32_t, 4> *indices) {
+  if (!decl)
+    return nullptr;
+
   if (indices) {
     // Indices are provided. Walk through the fields of the decl.
     const auto counter = fieldCounterVars.find(decl);
@@ -608,11 +610,12 @@ const CounterIdAliasPair *DeclResultIdMapper::getCounterIdAliasPair(
     if (counter != counterVars.end())
       return &counter->second;
   }
+
   return nullptr;
 }
 
 const CounterVarFields *
-DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) const {
+DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
   if (!decl)
     return nullptr;
 

+ 7 - 6
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -283,10 +283,11 @@ public:
   uint32_t createFnParam(const ParmVarDecl *param);
 
   /// \brief Creates the counter variable associated with the given param.
-  /// This is meant to be used for forward-declared functions.
+  /// This is meant to be used for forward-declared functions and this objects
+  /// of methods.
   ///
   /// Note: legalization specific code
-  inline void createFnParamCounterVar(const ParmVarDecl *param);
+  inline void createFnParamCounterVar(const VarDecl *param);
 
   /// \brief Creates a function-scope variable in the current function and
   /// returns its <result-id>.
@@ -383,12 +384,12 @@ public:
   /// if the given decl has no associated counter variable created.
   const CounterIdAliasPair *getCounterIdAliasPair(
       const DeclaratorDecl *decl,
-      const llvm::SmallVector<uint32_t, 4> *indices = nullptr) const;
+      const llvm::SmallVector<uint32_t, 4> *indices = nullptr);
 
   /// \brief Returns all the associated counters for the given decl. The decl is
   /// expected to be a struct containing alias RW/Append/Consume structured
   /// buffers. Returns nullptr if it does not.
-  const CounterVarFields *getCounterVarFields(const DeclaratorDecl *decl) const;
+  const CounterVarFields *getCounterVarFields(const DeclaratorDecl *decl);
 
   /// \brief Returns the <type-id> for the given cbuffer, tbuffer,
   /// ConstantBuffer, TextureBuffer, or push constant block.
@@ -717,8 +718,8 @@ bool DeclResultIdMapper::isInputStorageClass(const StageVar &v) {
          spv::StorageClass::Input;
 }
 
-void DeclResultIdMapper::createFnParamCounterVar(const ParmVarDecl *param) {
-  return createCounterVarForDecl(param);
+void DeclResultIdMapper::createFnParamCounterVar(const VarDecl *param) {
+  createCounterVarForDecl(param);
 }
 
 void DeclResultIdMapper::createFieldCounterVars(const DeclaratorDecl *decl) {

+ 50 - 15
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1652,14 +1652,18 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
   llvm::SmallVector<SpirvEvalInfo, 4> args; // Evaluated arguments
 
   if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr)) {
-    isNonStaticMemberCall =
-        !cast<CXXMethodDecl>(memberCall->getCalleeDecl())->isStatic();
+    const auto *memberFn = cast<CXXMethodDecl>(memberCall->getCalleeDecl());
+    isNonStaticMemberCall = !memberFn->isStatic();
+
     if (isNonStaticMemberCall) {
       // For non-static member calls, evaluate the object and pass it as the
       // first argument.
       const auto *object = memberCall->getImplicitObjectArgument();
       object = object->IgnoreParenNoopCasts(astContext);
 
+      // Update counter variable associated with the implicit object
+      tryToAssignCounterVar(getOrCreateDeclForMethodObject(memberFn), object);
+
       objectType = object->getType();
       objectEvalInfo = doExpr(object);
       uint32_t objectId = objectEvalInfo;
@@ -2817,6 +2821,9 @@ bool SPIRVEmitter::tryToAssignCounterVar(const DeclaratorDecl *dstDecl,
   // the translation of the real definition may not be started yet.
   if (const auto *param = dyn_cast<ParmVarDecl>(dstDecl))
     declIdMapper.createFnParamCounterVar(param);
+  // For implicit objects of methods. Similar to the above.
+  else if (const auto *thisObject = dyn_cast<ImplicitParamDecl>(dstDecl))
+    declIdMapper.createFnParamCounterVar(thisObject);
 
   // Handle AssocCounter#1 (see CounterVarFields comment)
   if (const auto *dstPair = declIdMapper.getCounterIdAliasPair(dstDecl)) {
@@ -2831,11 +2838,9 @@ bool SPIRVEmitter::tryToAssignCounterVar(const DeclaratorDecl *dstDecl,
   }
 
   // Handle AssocCounter#3
-  const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl);
   llvm::SmallVector<uint32_t, 4> srcIndices;
-  const auto *srcDecl = getReferencedDef(
-      collectArrayStructIndices(srcExpr, &srcIndices, /*rawIndex=*/true));
-  const auto *srcFields = declIdMapper.getCounterVarFields(srcDecl);
+  const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl);
+  const auto *srcFields = getIntermediateACSBufferCounter(srcExpr, &srcIndices);
 
   if (dstFields && srcFields) {
     if (!dstFields->assign(*srcFields, theBuilder, typeTranslator)) {
@@ -2875,12 +2880,8 @@ bool SPIRVEmitter::tryToAssignCounterVar(const Expr *dstExpr,
   // Handle AssocCounter#3 & AssocCounter#4
   llvm::SmallVector<uint32_t, 4> dstIndices;
   llvm::SmallVector<uint32_t, 4> srcIndices;
-  const auto *dstDecl = getReferencedDef(
-      collectArrayStructIndices(dstExpr, &dstIndices, /*rawIndex=*/true));
-  const auto *srcDecl = getReferencedDef(
-      collectArrayStructIndices(srcExpr, &srcIndices, /*rawIndex=*/true));
-  const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl);
-  const auto *srcFields = declIdMapper.getCounterVarFields(srcDecl);
+  const auto *srcFields = getIntermediateACSBufferCounter(srcExpr, &srcIndices);
+  const auto *dstFields = getIntermediateACSBufferCounter(dstExpr, &dstIndices);
 
   if (dstFields && srcFields) {
     return dstFields->assign(*srcFields, dstIndices, srcIndices, theBuilder,
@@ -2898,13 +2899,47 @@ SPIRVEmitter::getFinalACSBufferCounter(const Expr *expr) {
 
   // AssocCounter#2: referencing some non-struct field
   llvm::SmallVector<uint32_t, 4> indices;
-  if (const auto *decl = getReferencedDef(
-          collectArrayStructIndices(expr, &indices, /*rawIndex=*/true)))
-    return declIdMapper.getCounterIdAliasPair(decl, &indices);
+
+  const auto *base =
+      collectArrayStructIndices(expr, &indices, /*rawIndex=*/true);
+  const auto *decl =
+      (base && isa<CXXThisExpr>(base))
+          ? getOrCreateDeclForMethodObject(cast<CXXMethodDecl>(curFunction))
+          : getReferencedDef(base);
+  return declIdMapper.getCounterIdAliasPair(decl, &indices);
 
   return nullptr;
 }
 
+const CounterVarFields *SPIRVEmitter::getIntermediateACSBufferCounter(
+    const Expr *expr, llvm::SmallVector<uint32_t, 4> *indices) {
+  const auto *base =
+      collectArrayStructIndices(expr, indices, /*rawIndex=*/true);
+  const auto *decl =
+      (base && isa<CXXThisExpr>(base))
+          // Use the decl we created to represent the implicit object
+          ? getOrCreateDeclForMethodObject(cast<CXXMethodDecl>(curFunction))
+          // Find the referenced decl from the original source code
+          : getReferencedDef(base);
+
+  return declIdMapper.getCounterVarFields(decl);
+}
+
+const ImplicitParamDecl *
+SPIRVEmitter::getOrCreateDeclForMethodObject(const CXXMethodDecl *method) {
+  const auto found = thisDecls.find(method);
+  if (found != thisDecls.end())
+    return found->second;
+
+  const std::string name = method->getName().str() + ".this";
+  // Create a new identifier to convey the name
+  auto &identifier = astContext.Idents.get(name);
+
+  return thisDecls[method] = ImplicitParamDecl::Create(
+             astContext, /*DC=*/nullptr, SourceLocation(), &identifier,
+             method->getThisType(astContext)->getPointeeType());
+}
+
 SpirvEvalInfo
 SPIRVEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) {
   const bool isAppend = expr->getNumArgs() == 1;

+ 21 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -698,7 +698,16 @@ private:
   ///
   /// This method only handles final alias structured buffers, which means
   /// AssocCounter#1 and AssocCounter#2.
-  const CounterIdAliasPair *getFinalACSBufferCounter(const Expr *decl);
+  const CounterIdAliasPair *getFinalACSBufferCounter(const Expr *expr);
+  /// This method handles AssocCounter#3 and AssocCounter#4.
+  const CounterVarFields *
+  getIntermediateACSBufferCounter(const Expr *expr,
+                                  llvm::SmallVector<uint32_t, 4> *indices);
+
+  /// Gets or creates an ImplicitParamDecl to represent the implicit object
+  /// parameter of the given method.
+  const ImplicitParamDecl *
+  getOrCreateDeclForMethodObject(const CXXMethodDecl *method);
 
   /// \brief Loads numWords 32-bit unsigned integers or stores numWords 32-bit
   /// unsigned integers (based on the doStore parameter) to the given
@@ -843,6 +852,17 @@ private:
   /// Note: legalization specific code
   bool needsLegalization;
 
+  /// Mapping from methods to the decls to represent their implicit object
+  /// parameters
+  ///
+  /// We need this map because that we need to update the associated counters on
+  /// the implicit object when invoking method calls. The ImplicitParamDecl
+  /// mapped to serves as a key to find the associated counters in
+  /// DeclResultIdMapper.
+  ///
+  /// Note: legalization specific code
+  llvm::DenseMap<const CXXMethodDecl *, const ImplicitParamDecl *> thisDecls;
+
   /// Global variables that should be initialized once at the begining of the
   /// entry function.
   llvm::SmallVector<const VarDecl *, 4> toInitGloalVars;

+ 114 - 0
tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.counter.method.hlsl

@@ -0,0 +1,114 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct Bundle {
+      RWStructuredBuffer<float> rw;
+  AppendStructuredBuffer<float> append;
+ ConsumeStructuredBuffer<float> consume;
+};
+
+// Counter variables for the this object of getCSBuffer()
+// CHECK: %counter_var_getCSBuffer_this_0_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getCSBuffer_this_0_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getCSBuffer_this_0_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getCSBuffer_this_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getCSBuffer_this_1_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getCSBuffer_this_1_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// Counter variables for the this object of getASBuffer()
+// CHECK: %counter_var_getASBuffer_this_0_0_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getASBuffer_this_0_0_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getASBuffer_this_0_0_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getASBuffer_this_0_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getASBuffer_this_0_1_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getASBuffer_this_0_1_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+// Counter variables for the this object of getRWSBuffer()
+// CHECK: %counter_var_getRWSBuffer_this_0_0_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getRWSBuffer_this_0_0_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getRWSBuffer_this_0_0_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getRWSBuffer_this_0_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getRWSBuffer_this_0_1_1 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+// CHECK: %counter_var_getRWSBuffer_this_0_1_2 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
+
+struct TwoBundle {
+    Bundle b1;
+    Bundle b2;
+
+    // Checks at the end of the file
+    ConsumeStructuredBuffer<float> getCSBuffer() { return b1.consume; }
+};
+
+struct Wrapper {
+    TwoBundle b;
+
+    // Checks at the end of the file
+    AppendStructuredBuffer<float> getASBuffer() { return b.b1.append; }
+
+    // Checks at the end of the file
+    RWStructuredBuffer<float> getRWSBuffer() { return b.b2.rw; }
+};
+
+// CHECK-LABLE: %src_main = OpFunction
+float main() : VVV {
+    TwoBundle localBundle;
+    Wrapper   localWrapper;
+
+// CHECK:      [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localBundle_0_0
+// CHECK-NEXT:                    OpStore %counter_var_getCSBuffer_this_0_0 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localBundle_0_1
+// CHECK-NEXT:                    OpStore %counter_var_getCSBuffer_this_0_1 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localBundle_0_2
+// CHECK-NEXT:                    OpStore %counter_var_getCSBuffer_this_0_2 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localBundle_1_0
+// CHECK-NEXT:                    OpStore %counter_var_getCSBuffer_this_1_0 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localBundle_1_1
+// CHECK-NEXT:                    OpStore %counter_var_getCSBuffer_this_1_1 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localBundle_1_2
+// CHECK-NEXT:                    OpStore %counter_var_getCSBuffer_this_1_2 [[counter]]
+// CHECK-NEXT:                    OpFunctionCall %_ptr_Uniform_type_RWStructuredBuffer_float %TwoBundle_getCSBuffer %localBundle
+    float value = localBundle.getCSBuffer().Consume();
+
+// CHECK:      [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_0_0
+// CHECK-NEXT:                    OpStore %counter_var_getASBuffer_this_0_0_0 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_0_1
+// CHECK-NEXT:                    OpStore %counter_var_getASBuffer_this_0_0_1 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_0_2
+// CHECK-NEXT:                    OpStore %counter_var_getASBuffer_this_0_0_2 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_1_0
+// CHECK-NEXT:                    OpStore %counter_var_getASBuffer_this_0_1_0 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_1_1
+// CHECK-NEXT:                    OpStore %counter_var_getASBuffer_this_0_1_1 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_1_2
+// CHECK-NEXT:                    OpStore %counter_var_getASBuffer_this_0_1_2 [[counter]]
+// CHECK-NEXT:                    OpFunctionCall %_ptr_Uniform_type_RWStructuredBuffer_float %Wrapper_getASBuffer %localWrapper
+    localWrapper.getASBuffer().Append(4.2);
+
+// CHECK:      [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_0_0
+// CHECK-NEXT:                    OpStore %counter_var_getRWSBuffer_this_0_0_0 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_0_1
+// CHECK-NEXT:                    OpStore %counter_var_getRWSBuffer_this_0_0_1 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_0_2
+// CHECK-NEXT:                    OpStore %counter_var_getRWSBuffer_this_0_0_2 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_1_0
+// CHECK-NEXT:                    OpStore %counter_var_getRWSBuffer_this_0_1_0 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_1_1
+// CHECK-NEXT:                    OpStore %counter_var_getRWSBuffer_this_0_1_1 [[counter]]
+// CHECK-NEXT: [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_localWrapper_0_1_2
+// CHECK-NEXT:                    OpStore %counter_var_getRWSBuffer_this_0_1_2 [[counter]]
+// CHECK-NEXT:                    OpFunctionCall %_ptr_Uniform_type_RWStructuredBuffer_float %Wrapper_getRWSBuffer %localWrapper
+    RWStructuredBuffer<float> localRWSBuffer = localWrapper.getRWSBuffer();
+
+    return localRWSBuffer[5];
+}
+
+// CHECK-LABEL: %TwoBundle_getCSBuffer = OpFunction
+// CHECK:             [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_getCSBuffer_this_0_2
+// CHECK-NEXT:                           OpStore %counter_var_getCSBuffer [[counter]]
+
+// CHECK-LABEL: %Wrapper_getASBuffer = OpFunction
+// CHECK:           [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_getASBuffer_this_0_0_1
+// CHECK-NEXT:                         OpStore %counter_var_getASBuffer [[counter]]
+
+// CHECK-LABEL: %Wrapper_getRWSBuffer = OpFunction
+// CHECK:            [[counter:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_getRWSBuffer_this_0_1_0
+// CHECK-NEXT:                          OpStore %counter_var_getRWSBuffer [[counter]]

+ 18 - 12
tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.struct.hlsl

@@ -11,24 +11,30 @@ struct S {
     ConsumeStructuredBuffer<float4> consume;
 };
 
-// CHECK: %T = OpTypeStruct %_ptr_Uniform_type_StructuredBuffer_Basic %_ptr_Uniform_type_RWStructuredBuffer_Basic %_ptr_Uniform_type_StructuredBuffer_int
+// CHECK: %T = OpTypeStruct %_ptr_Uniform_type_StructuredBuffer_Basic %_ptr_Uniform_type_RWStructuredBuffer_Basic
 struct T {
       StructuredBuffer<Basic> ro;
     RWStructuredBuffer<Basic> rw;
-      StructuredBuffer<int>   ro2;
 
-    StructuredBuffer<Basic> getSBuffer() { return ro; }
-    StructuredBuffer<int>   getSBuffer2() { return ro2; }
 };
 
-// CHECK: %Combine = OpTypeStruct %S %T %_ptr_Uniform_type_ByteAddressBuffer %_ptr_Uniform_type_RWByteAddressBuffer
+struct U {
+    StructuredBuffer<Basic> basic;
+    StructuredBuffer<int>   integer;
+
+    StructuredBuffer<Basic> getSBufferStruct() { return basic;   }
+    StructuredBuffer<int>   getSBufferInt()    { return integer; }
+};
+
+// CHECK: %Combine = OpTypeStruct %S %T %_ptr_Uniform_type_ByteAddressBuffer %_ptr_Uniform_type_RWByteAddressBuffer %U
 struct Combine {
                       S s;
                       T t;
       ByteAddressBuffer ro;
     RWByteAddressBuffer rw;
+                      U u;
 
-    T getT() { return t; }
+    U getU() { return u; }
 };
 
        StructuredBuffer<Basic>  gSBuffer;
@@ -71,17 +77,17 @@ float4 main() : SV_Target {
 
 // Make sure that we create temporary variable for intermediate objects since
 // the function expect pointers as parameters.
-// CHECK:     [[call1:%\d+]] = OpFunctionCall %T %Combine_getT %c
-// CHECK-NEXT:                 OpStore %temp_var_T [[call1]]
-// CHECK-NEXT:[[call2:%\d+]] = OpFunctionCall %_ptr_Uniform_type_StructuredBuffer_Basic %T_getSBuffer %temp_var_T
+// CHECK:     [[call1:%\d+]] = OpFunctionCall %U %Combine_getU %c
+// CHECK-NEXT:                 OpStore %temp_var_U [[call1]]
+// CHECK-NEXT:[[call2:%\d+]] = OpFunctionCall %_ptr_Uniform_type_StructuredBuffer_Basic %U_getSBufferStruct %temp_var_U
 // CHECK-NEXT:  [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_v4float [[call2]] %int_0 %uint_10 %int_1
 // CHECK-NEXT:      {{%\d+}} = OpLoad %v4float [[ptr]]
-    float4 val = c.getT().getSBuffer()[10].b;
+    float4 val = c.getU().getSBufferStruct()[10].b;
 
 // Check StructuredBuffer of scalar type
-// CHECK:     [[call2:%\d+]] = OpFunctionCall %_ptr_Uniform_type_StructuredBuffer_int %T_getSBuffer2 %temp_var_T_0
+// CHECK:     [[call2:%\d+]] = OpFunctionCall %_ptr_Uniform_type_StructuredBuffer_int %U_getSBufferInt %temp_var_U_0
 // CHECK-NEXT:      {{%\d+}} = OpAccessChain %_ptr_Uniform_int [[call2]] %int_0 %uint_42
-    int index = c.getT().getSBuffer2()[42];
+    int index = c.getU().getSBufferInt()[42];
 
 // CHECK:      [[val:%\d+]] = OpLoad %Combine %c
 // CHECK:                     OpStore %param_var_comb [[val]]

+ 6 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -1006,9 +1006,15 @@ TEST_F(FileTest, SpirvLegalizationStructuredBufferCounter) {
               /*runValidation=*/true, /*relaxLogicalPointer=*/true);
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferCounterInStruct) {
+  // Tests using struct/class having associated counters
   runFileTest("spirv.legal.sbuffer.counter.struct.hlsl", Expect::Success,
               /*runValidation=*/true, /*relaxLogicalPointer=*/true);
 }
+TEST_F(FileTest, SpirvLegalizationStructuredBufferCounterInMethod) {
+  // Tests using methods whose enclosing struct/class having associated counters
+  runFileTest("spirv.legal.sbuffer.counter.method.hlsl", Expect::Success,
+              /*runValidation=*/true, /*relaxLogicalPointer=*/true);
+}
 TEST_F(FileTest, SpirvLegalizationStructuredBufferInStruct) {
   runFileTest("spirv.legal.sbuffer.struct.hlsl", Expect::Success,
               /*runValidation=*/true, /*relaxLogicalPointer=*/true);