Преглед на файлове

[spirv] Legalization: support structured/byte buffer in structs (#970)

We need to change the type of these struct fields to have an extra
level of pointer.

A local resource always has void as its layout rule (because local
resource is not in the Uniform storage class). So in the TypeTranslator,
when we are trying to translate a structured/byte buffer resource that
has void layout rule, we know it must be a local resource. Then we
apply an extra level of pointer to it. Because of TypeTranslator is
recursive, that automatically handles both stand-alone local resources
and the ones in structs. 

In the SPIRVEmitter, we need to have a way to tell whether a resource
is a local resource or not because if it is a local resource, we need to
OpLoad once to get the pointer to the aliased-to global resource.
That's why we have the containsAlias field in SpirvEvalInfo. We set it to
true in getTypeForPotentialAliasVar() for local resources. And do an
extra OpLoad to get the pointer in SPIRVEmitter if it is true.
Lei Zhang преди 7 години
родител
ревизия
415e190a8b

+ 7 - 26
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -517,12 +517,10 @@ uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
 
 
   const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
   const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
   info.setResultId(id);
   info.setResultId(id);
-  if (isAlias)
-    // No need to dereference to get the pointer. Alias function returns
-    // themselves are already pointers to values.
-    info.setValTypeId(0);
-  else
-    // All other cases should be normal rvalues.
+  // No need to dereference to get the pointer. Alias function returns
+  // themselves are already pointers to values. All other cases should be
+  // normal rvalues.
+  if (!isAlias)
     info.setRValue();
     info.setRValue();
 
 
   // Create alias counter variable if suitable
   // Create alias counter variable if suitable
@@ -1882,25 +1880,18 @@ uint32_t DeclResultIdMapper::getTypeForPotentialAliasVar(
   if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
   if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
     // This method is only intended to be used to create SPIR-V variables in the
     // This method is only intended to be used to create SPIR-V variables in the
     // Function or Private storage class.
     // Function or Private storage class.
-    assert(!varDecl->isExceptionVariable() || varDecl->isStaticDataMember());
+    assert(!varDecl->isExternallyVisible() || varDecl->isStaticDataMember());
   }
   }
 
 
   const QualType type = getTypeOrFnRetType(decl);
   const QualType type = getTypeOrFnRetType(decl);
   // Whether we should generate this decl as an alias variable.
   // Whether we should generate this decl as an alias variable.
   bool genAlias = false;
   bool genAlias = false;
-  // All texture/structured/byte buffers use GLSL std430 rules.
-  LayoutRule rule = LayoutRule::GLSLStd430;
 
 
   if (const auto *buffer = dyn_cast<HLSLBufferDecl>(decl->getDeclContext())) {
   if (const auto *buffer = dyn_cast<HLSLBufferDecl>(decl->getDeclContext())) {
     // For ConstantBuffer and TextureBuffer
     // For ConstantBuffer and TextureBuffer
     if (buffer->isConstantBufferView())
     if (buffer->isConstantBufferView())
       genAlias = true;
       genAlias = true;
-    // ConstantBuffer uses GLSL std140 rules.
-    // TODO: do we actually want to include constant/texture buffers
-    // in this method?
-    if (buffer->isCBuffer())
-      rule = LayoutRule::GLSLStd140;
-  } else if (TypeTranslator::isAKindOfStructuredOrByteBuffer(type)) {
+  } else if (TypeTranslator::isOrContainsAKindOfStructuredOrByteBuffer(type)) {
     genAlias = true;
     genAlias = true;
   }
   }
 
 
@@ -1910,18 +1901,8 @@ uint32_t DeclResultIdMapper::getTypeForPotentialAliasVar(
   if (genAlias) {
   if (genAlias) {
     needsLegalization = true;
     needsLegalization = true;
 
 
-    const uint32_t valType = typeTranslator.translateType(type, rule);
-    // All constant/texture/structured/byte buffers are in the Uniform
-    // storage class.
-    const auto ptrType =
-        theBuilder.getPointerType(valType, spv::StorageClass::Uniform);
-
     if (info)
     if (info)
-      info->setStorageClass(spv::StorageClass::Uniform)
-          .setLayoutRule(rule)
-          .setValTypeId(ptrType);
-
-    return ptrType;
+      info->setContainsAliasComponent(true);
   }
   }
 
 
   return typeTranslator.translateType(type);
   return typeTranslator.translateType(type);

+ 49 - 16
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -660,17 +660,22 @@ SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr,
       return info.setRValue();
       return info.setRValue();
     }
     }
 
 
-    uint32_t valType = 0;
-    if (valType = info.getValTypeId()) {
+    if (loadIfAliasVarRef(expr, info)) {
       // We are loading an alias variable as a whole here. This is likely for
       // We are loading an alias variable as a whole here. This is likely for
       // wholesale assignments or function returns. Need to load the pointer.
       // wholesale assignments or function returns. Need to load the pointer.
       //
       //
       // Note: legalization specific code
       // Note: legalization specific code
+      // TODO: It seems we should not set rvalue here since info is still
+      // holding a pointer. But it fails structured buffer assignment because
+      // of double loadIfGLValue() calls if we do not. Fix it.
+      return info.setRValue();
     }
     }
+
+    uint32_t valType = 0;
     // TODO: Ouch. Very hacky. We need special path to get the value type if
     // TODO: Ouch. Very hacky. We need special path to get the value type if
     // we are loading a whole ConstantBuffer/TextureBuffer since the normal
     // we are loading a whole ConstantBuffer/TextureBuffer since the normal
     // type translation path won't work.
     // type translation path won't work.
-    else if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
+    if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
       valType = declIdMapper.getCTBufferPushConstantTypeId(declContext);
       valType = declIdMapper.getCTBufferPushConstantTypeId(declContext);
     } else {
     } else {
       valType =
       valType =
@@ -684,18 +689,34 @@ SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr,
 
 
 SpirvEvalInfo SPIRVEmitter::loadIfAliasVarRef(const Expr *expr) {
 SpirvEvalInfo SPIRVEmitter::loadIfAliasVarRef(const Expr *expr) {
   auto info = doExpr(expr);
   auto info = doExpr(expr);
+  loadIfAliasVarRef(expr, info);
+  return info;
+}
 
 
-  if (const auto valTypeId = info.getValTypeId()) {
-    return info
-        // Load the pointer of the aliased-to-variable
-        .setResultId(theBuilder.createLoad(valTypeId, info))
-        // Set the value's <type-id> to zero to indicate that we've performed
-        // dereference over the pointer-to-pointer and now should fallback to
-        // the normal path
-        .setValTypeId(0);
+bool SPIRVEmitter::loadIfAliasVarRef(const Expr *varExpr, SpirvEvalInfo &info) {
+  if (info.containsAliasComponent() &&
+      TypeTranslator::isAKindOfStructuredOrByteBuffer(varExpr->getType())) {
+    // Aliased-to variables are all in the Uniform storage class with GLSL
+    // std430 layout rules.
+    const auto ptrType = typeTranslator.translateType(varExpr->getType());
+
+    // Load the pointer of the aliased-to-variable if the expression has a
+    // pointer to pointer type. That is, the expression itself is a lvalue.
+    // (Note that we translate alias function return values as pointer types,
+    // not pointer to pointer types.)
+    if (varExpr->isGLValue())
+      info.setResultId(theBuilder.createLoad(ptrType, info));
+
+    info.setStorageClass(spv::StorageClass::Uniform)
+        .setLayoutRule(LayoutRule::GLSLStd430)
+        // Set to false to indicate that we've performed dereference over the
+        // pointer-to-pointer and now should fallback to the normal path
+        .setContainsAliasComponent(false);
+
+    return true;
   }
   }
 
 
-  return info;
+  return false;
 }
 }
 
 
 uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType,
 uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType,
@@ -977,7 +998,7 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
       else
       else
         storeValue(varId, loadIfGLValue(init), decl->getType());
         storeValue(varId, loadIfGLValue(init), decl->getType());
 
 
-      // Update counter variable associatd with local variables
+      // Update counter variable associated with local variables
       tryToAssignCounterVar(decl, init);
       tryToAssignCounterVar(decl, init);
     }
     }
 
 
@@ -1427,7 +1448,7 @@ void SPIRVEmitter::doIfStmt(const IfStmt *ifStmt) {
 
 
 void SPIRVEmitter::doReturnStmt(const ReturnStmt *stmt) {
 void SPIRVEmitter::doReturnStmt(const ReturnStmt *stmt) {
   if (const auto *retVal = stmt->getRetValue()) {
   if (const auto *retVal = stmt->getRetValue()) {
-    // Update counter variable associatd with function returns
+    // Update counter variable associated with function returns
     tryToAssignCounterVar(curFunction, retVal);
     tryToAssignCounterVar(curFunction, retVal);
 
 
     const auto retInfo = doExpr(retVal);
     const auto retInfo = doExpr(retVal);
@@ -1559,7 +1580,7 @@ SpirvEvalInfo SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
   // For other binary operations, we need to evaluate lhs before rhs.
   // For other binary operations, we need to evaluate lhs before rhs.
   if (opcode == BO_Assign) {
   if (opcode == BO_Assign) {
     if (const auto *dstDecl = getReferencedDef(expr->getLHS()))
     if (const auto *dstDecl = getReferencedDef(expr->getLHS()))
-      // Update counter variable associatd with lhs of assignments
+      // Update counter variable associated with lhs of assignments
       tryToAssignCounterVar(dstDecl, expr->getRHS());
       tryToAssignCounterVar(dstDecl, expr->getRHS());
 
 
     return processAssignment(expr->getLHS(), loadIfGLValue(expr->getRHS()),
     return processAssignment(expr->getLHS(), loadIfGLValue(expr->getRHS()),
@@ -4645,6 +4666,18 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
       const auto thisBaseType = thisBase->getType();
       const auto thisBaseType = thisBase->getType();
       const Expr *base = collectArrayStructIndices(thisBase, indices);
       const Expr *base = collectArrayStructIndices(thisBase, indices);
 
 
+      if (thisBaseType != base->getType() &&
+          TypeTranslator::isAKindOfStructuredOrByteBuffer(thisBaseType)) {
+        // The immediate base is a kind of structured or byte buffer. It should
+        // be an alias variable. Break the normal index collecting chain.
+        // Return the immediate base as the base so that we can apply other
+        // hacks for legalization over it.
+        //
+        // Note: legalization specific code
+        indices->clear();
+        base = thisBase;
+      }
+
       // If the base is a StructureType, we need to push an addtional index 0
       // If the base is a StructureType, we need to push an addtional index 0
       // here. This is because we created an additional OpTypeRuntimeArray
       // here. This is because we created an additional OpTypeRuntimeArray
       // in the structure.
       // in the structure.
@@ -7174,7 +7207,7 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
     if (const auto *init = varDecl->getInit()) {
     if (const auto *init = varDecl->getInit()) {
       storeValue(varInfo, doExpr(init), varDecl->getType());
       storeValue(varInfo, doExpr(init), varDecl->getType());
 
 
-      // Update counter variable associatd with global variables
+      // Update counter variable associated with global variables
       tryToAssignCounterVar(varDecl, init);
       tryToAssignCounterVar(varDecl, init);
     } else {
     } else {
       const auto typeId = typeTranslator.translateType(varDecl->getType());
       const auto typeId = typeTranslator.translateType(varDecl->getType());

+ 7 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -118,6 +118,13 @@ private:
   /// Note: legalization specific code
   /// Note: legalization specific code
   SpirvEvalInfo loadIfAliasVarRef(const Expr *expr);
   SpirvEvalInfo loadIfAliasVarRef(const Expr *expr);
 
 
+  /// Loads the pointer of the aliased-to-variable and ajusts aliasVarInfo
+  /// accordingly if aliasVarExpr is referencing an alias variable. Returns true
+  /// if aliasVarInfo is changed, false otherwise.
+  ///
+  /// Note: legalization specific code
+  bool loadIfAliasVarRef(const Expr *aliasVarExpr, SpirvEvalInfo &aliasVarInfo);
+
 private:
 private:
   /// Translates the given frontend binary operator into its SPIR-V equivalent
   /// Translates the given frontend binary operator into its SPIR-V equivalent
   /// taking consideration of the operand type.
   /// taking consideration of the operand type.

+ 13 - 12
tools/clang/lib/SPIRV/SpirvEvalInfo.h

@@ -79,8 +79,8 @@ public:
   /// Handly implicit conversion to test whether the <result-id> is valid.
   /// Handly implicit conversion to test whether the <result-id> is valid.
   operator bool() const { return resultId != 0; }
   operator bool() const { return resultId != 0; }
 
 
-  inline SpirvEvalInfo &setValTypeId(uint32_t id);
-  uint32_t getValTypeId() const { return valTypeId; }
+  inline SpirvEvalInfo &setContainsAliasComponent(bool);
+  bool containsAliasComponent() const { return containsAlias; }
 
 
   inline SpirvEvalInfo &setStorageClass(spv::StorageClass sc);
   inline SpirvEvalInfo &setStorageClass(spv::StorageClass sc);
   spv::StorageClass getStorageClass() const { return storageClass; }
   spv::StorageClass getStorageClass() const { return storageClass; }
@@ -99,14 +99,15 @@ public:
 
 
 private:
 private:
   uint32_t resultId;
   uint32_t resultId;
-  /// The value's <type-id> for this variable.
+  /// Indicates whether this evaluation result contains alias variables
   ///
   ///
-  /// This field should only be non-zero for original alias variables, which is
-  /// of pointer-to-pointer type. After dereferencing the alias variable, this
-  /// should be set to zero to let CodeGen fall back to normal handling path.
+  /// This field should only be true for stand-alone alias variables, which is
+  /// of pointer-to-pointer type, or struct variables containing alias fields.
+  /// After dereferencing the alias variable, this should be set to false to let
+  /// CodeGen fall back to normal handling path.
   ///
   ///
   /// Note: legalization specific code
   /// Note: legalization specific code
-  uint32_t valTypeId;
+  bool containsAlias;
 
 
   spv::StorageClass storageClass;
   spv::StorageClass storageClass;
   LayoutRule layoutRule;
   LayoutRule layoutRule;
@@ -117,9 +118,9 @@ private:
 };
 };
 
 
 SpirvEvalInfo::SpirvEvalInfo(uint32_t id)
 SpirvEvalInfo::SpirvEvalInfo(uint32_t id)
-    : resultId(id), valTypeId(0), storageClass(spv::StorageClass::Function),
-      layoutRule(LayoutRule::Void), isRValue_(false), isConstant_(false),
-      isRelaxedPrecision_(false) {}
+    : resultId(id), containsAlias(false),
+      storageClass(spv::StorageClass::Function), layoutRule(LayoutRule::Void),
+      isRValue_(false), isConstant_(false), isRelaxedPrecision_(false) {}
 
 
 SpirvEvalInfo &SpirvEvalInfo::setResultId(uint32_t id) {
 SpirvEvalInfo &SpirvEvalInfo::setResultId(uint32_t id) {
   resultId = id;
   resultId = id;
@@ -132,8 +133,8 @@ SpirvEvalInfo SpirvEvalInfo::substResultId(uint32_t newId) const {
   return info;
   return info;
 }
 }
 
 
-SpirvEvalInfo &SpirvEvalInfo::setValTypeId(uint32_t id) {
-  valTypeId = id;
+SpirvEvalInfo &SpirvEvalInfo::setContainsAliasComponent(bool contains) {
+  containsAlias = contains;
   return *this;
   return *this;
 }
 }
 
 

+ 50 - 4
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -488,6 +488,22 @@ bool TypeTranslator::isAKindOfStructuredOrByteBuffer(QualType type) {
   return false;
   return false;
 }
 }
 
 
+bool TypeTranslator::isOrContainsAKindOfStructuredOrByteBuffer(QualType type) {
+  if (const RecordType *recordType = type->getAs<RecordType>()) {
+    StringRef name = recordType->getDecl()->getName();
+    if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
+        name == "ByteAddressBuffer" || name == "RWByteAddressBuffer" ||
+        name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer")
+      return true;
+
+    for (const auto *field : recordType->getDecl()->fields()) {
+      if (isOrContainsAKindOfStructuredOrByteBuffer(field->getType()))
+        return true;
+    }
+  }
+  return false;
+}
+
 bool TypeTranslator::isStructuredBuffer(QualType type) {
 bool TypeTranslator::isStructuredBuffer(QualType type) {
   const auto *recordType = type->getAs<RecordType>();
   const auto *recordType = type->getAs<RecordType>();
   if (!recordType)
   if (!recordType)
@@ -836,10 +852,20 @@ uint32_t TypeTranslator::translateResourceType(QualType type, LayoutRule rule) {
 
 
   if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
   if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
       name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer") {
       name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer") {
-    auto &context = *theBuilder.getSPIRVContext();
     // StructureBuffer<S> will be translated into an OpTypeStruct with one
     // StructureBuffer<S> will be translated into an OpTypeStruct with one
     // field, which is an OpTypeRuntimeArray of OpTypeStruct (S).
     // field, which is an OpTypeRuntimeArray of OpTypeStruct (S).
 
 
+    // If layout rule is void, it means these resource types are used for
+    // declaring local resources, which should be created as alias variables.
+    // The aliased-to variable should surely be in the Uniform storage class,
+    // which has layout decorations.
+    bool asAlias = false;
+    if (rule == LayoutRule::Void) {
+      asAlias = true;
+      rule = LayoutRule::GLSLStd430;
+    }
+
+    auto &context = *theBuilder.getSPIRVContext();
     const auto s = hlsl::GetHLSLResourceResultType(type);
     const auto s = hlsl::GetHLSLResourceResultType(type);
     const uint32_t structType = translateType(s, rule);
     const uint32_t structType = translateType(s, rule);
     std::string structName;
     std::string structName;
@@ -864,16 +890,36 @@ uint32_t TypeTranslator::translateResourceType(QualType type, LayoutRule rule) {
       decorations.push_back(Decoration::getNonWritable(context, 0));
       decorations.push_back(Decoration::getNonWritable(context, 0));
     decorations.push_back(Decoration::getBufferBlock(context));
     decorations.push_back(Decoration::getBufferBlock(context));
     const std::string typeName = "type." + name.str() + "." + structName;
     const std::string typeName = "type." + name.str() + "." + structName;
-    return theBuilder.getStructType(raType, typeName, {}, decorations);
+    const auto valType =
+        theBuilder.getStructType(raType, typeName, {}, decorations);
+
+    if (asAlias) {
+      // All structured buffers are in the Uniform storage class.
+      return theBuilder.getPointerType(valType, spv::StorageClass::Uniform);
+    } else {
+      return valType;
+    }
   }
   }
 
 
   // ByteAddressBuffer types.
   // ByteAddressBuffer types.
   if (name == "ByteAddressBuffer") {
   if (name == "ByteAddressBuffer") {
-    return theBuilder.getByteAddressBufferType(/*isRW*/ false);
+    const auto bufferType = theBuilder.getByteAddressBufferType(/*isRW*/ false);
+    if (rule == LayoutRule::Void) {
+      // All byte address buffers are in the Uniform storage class.
+      return theBuilder.getPointerType(bufferType, spv::StorageClass::Uniform);
+    } else {
+      return bufferType;
+    }
   }
   }
   // RWByteAddressBuffer types.
   // RWByteAddressBuffer types.
   if (name == "RWByteAddressBuffer") {
   if (name == "RWByteAddressBuffer") {
-    return theBuilder.getByteAddressBufferType(/*isRW*/ true);
+    const auto bufferType = theBuilder.getByteAddressBufferType(/*isRW*/ true);
+    if (rule == LayoutRule::Void) {
+      // All byte address buffers are in the Uniform storage class.
+      return theBuilder.getPointerType(bufferType, spv::StorageClass::Uniform);
+    } else {
+      return bufferType;
+    }
   }
   }
 
 
   // Buffer and RWBuffer types
   // Buffer and RWBuffer types

+ 5 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -95,6 +95,11 @@ public:
   /// (RW)ByteAddressBuffer, or {Append|Consume}StructuredBuffer.
   /// (RW)ByteAddressBuffer, or {Append|Consume}StructuredBuffer.
   static bool isAKindOfStructuredOrByteBuffer(QualType type);
   static bool isAKindOfStructuredOrByteBuffer(QualType type);
 
 
+  /// \brief Returns true if the given type is the HLSL (RW)StructuredBuffer,
+  /// (RW)ByteAddressBuffer, {Append|Consume}StructuredBuffer, or a struct
+  /// containing one of the above.
+  static bool isOrContainsAKindOfStructuredOrByteBuffer(QualType type);
+
   /// \brief Returns true if the given type is the HLSL Buffer type.
   /// \brief Returns true if the given type is the HLSL Buffer type.
   static bool isBuffer(QualType type);
   static bool isBuffer(QualType type);
 
 

+ 86 - 0
tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.struct.hlsl

@@ -0,0 +1,86 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct Basic {
+    float3 a;
+    float4 b;
+};
+
+// CHECK: %S = OpTypeStruct %_ptr_Uniform_type_AppendStructuredBuffer_v4float %_ptr_Uniform_type_AppendStructuredBuffer_v4float
+struct S {
+     AppendStructuredBuffer<float4> append;
+    ConsumeStructuredBuffer<float4> consume;
+};
+
+// CHECK: %T = OpTypeStruct %_ptr_Uniform_type_StructuredBuffer_Basic %_ptr_Uniform_type_RWStructuredBuffer_Basic
+struct T {
+      StructuredBuffer<Basic> ro;
+    RWStructuredBuffer<Basic> rw;
+};
+
+// CHECK: %Combine = OpTypeStruct %S %T %_ptr_Uniform_type_ByteAddressBuffer %_ptr_Uniform_type_RWByteAddressBuffer
+struct Combine {
+                      S s;
+                      T t;
+      ByteAddressBuffer ro;
+    RWByteAddressBuffer rw;
+};
+
+       StructuredBuffer<Basic>  gSBuffer;
+     RWStructuredBuffer<Basic>  gRWSBuffer;
+ AppendStructuredBuffer<float4> gASBuffer;
+ConsumeStructuredBuffer<float4> gCSBuffer;
+      ByteAddressBuffer         gBABuffer;
+    RWByteAddressBuffer         gRWBABuffer;
+
+float4 foo(Combine comb);
+
+float4 main() : SV_Target {
+    Combine c;
+
+// CHECK:      [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_AppendStructuredBuffer_v4float %c %int_0 %int_0
+// CHECK-NEXT:                OpStore [[ptr]] %gASBuffer
+    c.s.append = gASBuffer;
+// CHECK:      [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_AppendStructuredBuffer_v4float %c %int_0 %int_1
+// CHECK-NEXT:                OpStore [[ptr]] %gCSBuffer
+    c.s.consume = gCSBuffer;
+
+    T t;
+// CHECK:      [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_StructuredBuffer_Basic %t %int_0
+// CHECK-NEXT:                OpStore [[ptr]] %gSBuffer
+    t.ro = gSBuffer;
+// CHECK:      [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_RWStructuredBuffer_Basic %t %int_1
+// CHECK-NEXT:                OpStore [[ptr]] %gRWSBuffer
+    t.rw = gRWSBuffer;
+// CHECK:      [[val:%\d+]] = OpLoad %T %t
+// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_T %c %int_1
+// CHECK-NEXT:                OpStore [[ptr]] [[val]]
+    c.t = t;
+
+// CHECK:      [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_ByteAddressBuffer %c %int_2
+// CHECK-NEXT:                OpStore [[ptr]] %gBABuffer
+    c.ro = gBABuffer;
+// CHECK:      [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_RWByteAddressBuffer %c %int_3
+// CHECK-NEXT:                OpStore [[ptr]] %gRWBABuffer
+    c.rw = gRWBABuffer;
+
+// CHECK:      [[val:%\d+]] = OpLoad %Combine %c
+// CHECK-NEXT:                OpStore %param_var_comb [[val]]
+    return foo(c);
+}
+float4 foo(Combine comb) {
+    // TODO: add support for associated counters of struct fields
+    // comb.s.append.Append(float4(1, 2, 3, 4));
+    // float4 val = comb.s.consume.Consume();
+    // comb.t.rw[5].a = 4.2;
+
+// CHECK:      [[ptr1:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_ByteAddressBuffer %comb %int_2
+// CHECK-NEXT: [[ptr2:%\d+]] = OpLoad %_ptr_Uniform_type_ByteAddressBuffer [[ptr1]]
+// CHECK-NEXT:  [[idx:%\d+]] = OpShiftRightLogical %uint %uint_5 %uint_2
+// CHECK-NEXT:      {{%\d+}} = OpAccessChain %_ptr_Uniform_uint [[ptr2]] %uint_0 [[idx]]
+    uint val = comb.ro.Load(5);
+
+// CHECK:      [[ptr1:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_StructuredBuffer_Basic %comb %int_1 %int_0
+// CHECK-NEXT: [[ptr2:%\d+]] = OpLoad %_ptr_Uniform_type_StructuredBuffer_Basic [[ptr1]]
+// CHECK-NEXT:      {{%\d+}} = OpAccessChain %_ptr_Uniform_v4float [[ptr2]] %int_0 %uint_0 %int_1
+    return comb.t.ro[0].b;
+}

+ 5 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -1003,6 +1003,11 @@ TEST_F(FileTest, SpirvLegalizationStructuredBufferCounter) {
               // The generated SPIR-V needs legalization.
               // The generated SPIR-V needs legalization.
               /*runValidation=*/false);
               /*runValidation=*/false);
 }
 }
+TEST_F(FileTest, SpirvLegalizationStructuredBufferInStruct) {
+  runFileTest("spirv.legal.sbuffer.struct.hlsl", Expect::Success,
+              // The generated SPIR-V needs legalization.
+              /*runValidation=*/false);
+}
 TEST_F(FileTest, SpirvLegalizationConstantBuffer) {
 TEST_F(FileTest, SpirvLegalizationConstantBuffer) {
   runFileTest("spirv.legal.cbuffer.hlsl");
   runFileTest("spirv.legal.cbuffer.hlsl");
 }
 }