Browse Source

[spirv] Emit alias structured/byte buffer as pointer to pointer (#918)

This is to aid legalization passes to do their work.

For non-external variables of the following types

* (RW)StructuredBuffer
* (RW)ByteAddressBuffer
* AppendStructuredBuffer
* ConsumeStructuredBuffer

An extra level of pointer is applied. We use this extra level of
pointer to indicate that they are aliases. Loads and stores of
these alias variables will get the pointers of the aliased-to-
variables.

Associated counters is not handled in this commit.
Lei Zhang 7 years ago
parent
commit
eaa3e8e26f

+ 61 - 6
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -244,30 +244,42 @@ SpirvEvalInfo DeclResultIdMapper::getDeclResultId(const ValueDecl *decl,
 }
 
 uint32_t DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
-  const uint32_t type = typeTranslator.translateType(param->getType());
+  bool isAlias = false;
+  const uint32_t type = getTypeForPotentialAliasVar(param, &isAlias);
   const uint32_t ptrType =
       theBuilder.getPointerType(type, spv::StorageClass::Function);
   const uint32_t id = theBuilder.addFnParam(ptrType, param->getName());
-  astDecls[param] = SpirvEvalInfo(id);
+  astDecls[param] = SpirvEvalInfo(id)
+                        .setStorageClass(isAlias ? spv::StorageClass::Uniform
+                                                 : spv::StorageClass::Function)
+                        .setValTypeId(isAlias ? type : 0);
 
   return id;
 }
 
 uint32_t DeclResultIdMapper::createFnVar(const VarDecl *var,
                                          llvm::Optional<uint32_t> init) {
-  const uint32_t type = typeTranslator.translateType(var->getType());
+  bool isAlias = false;
+  const uint32_t type = getTypeForPotentialAliasVar(var, &isAlias);
   const uint32_t id = theBuilder.addFnVar(type, var->getName(), init);
-  astDecls[var] = SpirvEvalInfo(id);
+  astDecls[var] = SpirvEvalInfo(id)
+                      .setStorageClass(isAlias ? spv::StorageClass::Uniform
+                                               : spv::StorageClass::Function)
+                      .setValTypeId(isAlias ? type : 0);
 
   return id;
 }
 
 uint32_t DeclResultIdMapper::createFileVar(const VarDecl *var,
                                            llvm::Optional<uint32_t> init) {
-  const uint32_t type = typeTranslator.translateType(var->getType());
+  bool isAlias = false;
+  const uint32_t type = getTypeForPotentialAliasVar(var, &isAlias);
   const uint32_t id = theBuilder.addModuleVar(type, spv::StorageClass::Private,
                                               var->getName(), init);
-  astDecls[var] = SpirvEvalInfo(id).setStorageClass(spv::StorageClass::Private);
+  astDecls[var] = SpirvEvalInfo(id)
+                      .setStorageClass(isAlias ? spv::StorageClass::Uniform
+                                               : spv::StorageClass::Private)
+                      .setValTypeId(isAlias ? type : 0);
 
   return id;
 }
@@ -1792,5 +1804,48 @@ DeclResultIdMapper::getStorageClassForSigPoint(const hlsl::SigPoint *sigPoint) {
   return sc;
 }
 
+uint32_t
+DeclResultIdMapper::getTypeForPotentialAliasVar(const DeclaratorDecl *decl,
+                                                bool *shouldBeAlias) {
+  if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
+    // This method is only intended to be used to create SPIR-V variables in the
+    // Function or Private storage class.
+    assert(!varDecl->isExceptionVariable() || varDecl->isStaticDataMember());
+  }
+
+  const QualType type = getTypeOrFnRetType(decl);
+  // Whether we should generate this decl as an alias variable.
+  bool genAlias = false;
+  // All texture/structured/byte buffers use GLSL std430 rules.
+  LayoutRule rule = LayoutRule::GLSLStd430;
+
+  if (const auto *buffer = dyn_cast<HLSLBufferDecl>(decl->getDeclContext())) {
+    // For ConstantBuffer and TextureBuffer
+    if (buffer->isConstantBufferView())
+      genAlias = true;
+    // ConstantBuffer uses GLSL std140 rules.
+    // TODO: do we actually want to include constant/texture buffers
+    // in this method?
+    if (buffer->isCBuffer())
+      rule = LayoutRule::GLSLStd140;
+  } else if (TypeTranslator::isAKindOfStructuredOrByteBuffer(type)) {
+    genAlias = true;
+  }
+
+  if (shouldBeAlias)
+    *shouldBeAlias = genAlias;
+
+  if (genAlias) {
+    needsLegalization = true;
+
+    const uint32_t valType = typeTranslator.translateType(type, rule);
+    // All constant/texture/structured/byte buffers are in the Uniform
+    // storage class.
+    return theBuilder.getPointerType(valType, spv::StorageClass::Uniform);
+  }
+
+  return typeTranslator.translateType(type);
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 62 - 0
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -214,6 +214,15 @@ public:
   /// \brief Creates a PushConstant block from the given decl.
   uint32_t createPushConstant(const VarDecl *decl);
 
+  /// \brief Returns the suitable type for the given decl, considering the
+  /// given decl could possibly be created as an alias variable. If true, a
+  /// pointer-to-the-value type will be returned, otherwise, just return the
+  /// normal value type.
+  ///
+  /// Note: legalization specific code
+  uint32_t getTypeForPotentialAliasVar(const DeclaratorDecl *var,
+                                       bool *shouldBeAlias = nullptr);
+
   /// \brief Sets the <result-id> of the entry function.
   void setEntryFunctionId(uint32_t id) { entryFunctionId = id; }
 
@@ -300,6 +309,8 @@ public:
   /// module under construction.
   bool decorateResourceBindings();
 
+  bool requiresLegalization() const { return needsLegalization; }
+
 private:
   /// \brief Wrapper method to create a fatal error message and report it
   /// in the diagnostic engine associated with this consumer.
@@ -476,6 +487,56 @@ private:
   /// to the <type-id>
   llvm::DenseMap<const DeclContext *, uint32_t> ctBufferPCTypeIds;
 
+  /// Whether the translated SPIR-V binary needs legalization.
+  ///
+  /// The following cases will require legalization:
+  ///
+  /// 1. Opaque types (textures, samplers) within structs
+  /// 2. Structured buffer assignments
+  ///
+  /// This covers the second case:
+  ///
+  /// When we have a kind of structured or byte buffer, meaning one of the
+  /// following
+  ///
+  /// * StructuredBuffer
+  /// * RWStructuredBuffer
+  /// * AppendStructuredBuffer
+  /// * ConsumeStructuredBuffer
+  /// * ByteAddressStructuredBuffer
+  /// * RWByteAddressStructuredBuffer
+  ///
+  /// and assigning to them (using operator=, passing in as function parameter,
+  /// returning as function return), we need legalization.
+  ///
+  /// All variable definitions (including static/non-static local/global
+  /// variables, function parameters/returns) will gain another level of
+  /// pointerness, unless they will generate externally visible SPIR-V
+  /// variables. So variables and parameters will be of pointer-to-pointer type,
+  /// while function returns will be of pointer type. We adopt this mechanism to
+  /// convey to the legalization passes that they are *alias* variables, and
+  /// all accesses should happen to the aliased-to-variables. Loading such an
+  /// alias variable will give the pointer to the aliased-to-variable, while
+  /// storing into such an alias variable should write the pointer to the
+  /// aliased-to-variable.
+  ///
+  /// Based on the above, CodeGen should take care of the following AST nodes:
+  ///
+  /// * Definition of alias variables: should add another level of pointers
+  /// * Assigning non-alias variables to alias variables: should avoid the load
+  ///   over the non-alias variables
+  /// * Accessing alias variables: should load the pointer first and then
+  ///   further compose access chains.
+  ///
+  /// Note that the associated counters bring about their own complication.
+  /// We also need to apply the alias mechanism for them.
+  ///
+  /// If this is true, SPIRV-Tools legalization passes will be executed after
+  /// the translation to legalize the generated SPIR-V binary.
+  ///
+  /// Note: legalization specific code
+  bool needsLegalization;
+
 public:
   /// The gl_PerVertex structs for both input and output
   GlPerVertex glPerVertex;
@@ -488,6 +549,7 @@ DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
     : shaderModel(model), theBuilder(builder), spirvOptions(options),
       astContext(context), diags(context.getDiagnostics()),
       typeTranslator(context, builder, diags), entryFunctionId(0),
+      needsLegalization(false),
       glPerVertex(model, context, builder, typeTranslator) {}
 
 bool DeclResultIdMapper::decorateStageIOLocations() {

+ 64 - 19
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -169,9 +169,9 @@ const Expr *isStructuredBufferLoad(const Expr *expr, const Expr **index) {
 }
 
 /// Returns true if the given VarDecl will be translated into a SPIR-V variable
-/// in Private or Function storage class.
-inline bool isNonExternalVar(const VarDecl *var) {
-  return !var->isExternallyVisible() || var->isStaticDataMember();
+/// not in the Private or Function storage class.
+inline bool isExternalVar(const VarDecl *var) {
+  return var->isExternallyVisible() && !var->isStaticDataMember();
 }
 
 /// Returns the referenced variable's DeclContext if the given expr is
@@ -205,6 +205,19 @@ bool isConstantTextureBufferLoad(const Expr *expr) {
   return false;
 }
 
+/// Returns true if the given expr is an DeclRefExpr referencing a kind of
+/// structured or byte buffer and it is non-alias one.
+///
+/// Note: legalization specific code
+bool isReferencingNonAliasStructuredOrByteBuffer(const Expr *expr) {
+  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr->IgnoreParenCasts()))
+    if (const auto *varDecl = dyn_cast<VarDecl>(declRefExpr->getFoundDecl()))
+      if (TypeTranslator::isAKindOfStructuredOrByteBuffer(varDecl->getType())) {
+        return isExternalVar(varDecl);
+      }
+  return false;
+}
+
 bool spirvToolsLegalize(std::vector<uint32_t> *module, std::string *messages) {
   spvtools::Optimizer optimizer(SPV_ENV_VULKAN_1_0);
 
@@ -395,7 +408,7 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
 
   if (!spirvOptions.codeGenHighLevel) {
     // Run legalization passes
-    if (needsLegalization) {
+    if (needsLegalization || declIdMapper.requiresLegalization()) {
       std::string messages;
       if (!spirvToolsLegalize(&m, &messages)) {
         emitFatalError("failed to legalize SPIR-V: %0", {}) << messages;
@@ -586,11 +599,26 @@ SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr) {
   auto info = doExpr(expr);
 
   if (!info.isRValue()) {
+    // Check whether we are trying to load an externally visible structured/byte
+    // buffer as a whole. If true, it means we are creating alias for it. Avoid
+    // the load and write the pointer directly to the alias variable then.
+    //
+    // Note: legalization specific code
+    if (isReferencingNonAliasStructuredOrByteBuffer(expr)) {
+      return info.setRValue();
+    }
+
     uint32_t valType = 0;
+    if (valType = info.getValTypeId()) {
+      // We are loading an alias variable as a whole here. This is likely for
+      // wholesale assignments or function returns. Need to load the pointer.
+      //
+      // Note: legalization specific code
+    }
     // TODO: Ouch. Very hacky. We need special path to get the value type if
     // we are loading a whole ConstantBuffer/TextureBuffer since the normal
     // type translation path won't work.
-    if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
+    else if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
       valType = declIdMapper.getCTBufferPushConstantTypeId(declContext);
     } else {
       valType =
@@ -602,6 +630,22 @@ SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr) {
   return info;
 }
 
+SpirvEvalInfo SPIRVEmitter::loadIfAliasVarRef(const Expr *expr) {
+  auto info = doExpr(expr);
+
+  if (const auto valTypeId = info.getValTypeId()) {
+    return info
+        // Load the pointer of the aliased-to-variable
+        .setResultId(theBuilder.createLoad(valTypeId, info))
+        // Set the value's <type-id> to zero to indicate that we've performed
+        // dereference over the pointer-to-pointer and now should fallback to
+        // the normal path
+        .setValTypeId(0);
+  }
+
+  return info;
+}
+
 uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType,
                                   QualType toType, SourceLocation srcLoc) {
   if (isFloatOrVecOfFloatType(toType))
@@ -672,7 +716,7 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
       TypeTranslator::isOpaqueStructType(decl->getReturnType()))
     needsLegalization = true;
 
-  const uint32_t retType = typeTranslator.translateType(decl->getReturnType());
+  const uint32_t retType = declIdMapper.getTypeForPotentialAliasVar(decl);
 
   // Construct the function signature.
   llvm::SmallVector<uint32_t, 4> paramTypes;
@@ -697,7 +741,7 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
   }
 
   for (const auto *param : decl->params()) {
-    const uint32_t valueType = typeTranslator.translateType(param->getType());
+    const uint32_t valueType = declIdMapper.getTypeForPotentialAliasVar(param);
     const uint32_t ptrType =
         theBuilder.getPointerType(valueType, spv::StorageClass::Function);
     paramTypes.push_back(ptrType);
@@ -851,7 +895,9 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
   // File scope variables (static "global" and "local" variables) belongs to
   // the Private storage class, while function scope variables (normal "local"
   // variables) belongs to the Function storage class.
-  if (isNonExternalVar(decl)) {
+  if (isExternalVar(decl)) {
+    varId = declIdMapper.createExternVar(decl);
+  } else {
     // We already know the variable is not externally visible here. If it does
     // not have local storage, it should be file scope variable.
     const bool isFileScopeVar = !decl->hasLocalStorage();
@@ -882,8 +928,6 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
       else
         storeValue(varId, loadIfGLValue(init), decl->getType(), init);
     }
-  } else {
-    varId = declIdMapper.createExternVar(decl);
   }
 
   if (TypeTranslator::isRelaxedPrecisionType(decl->getType())) {
@@ -1551,7 +1595,7 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
 
     // We need to create variables for holding the values to be used as
     // arguments. The variables themselves are of pointer types.
-    const uint32_t varType = typeTranslator.translateType(arg->getType());
+    const uint32_t varType = declIdMapper.getTypeForPotentialAliasVar(param);
     const std::string varName = "param.var." + param->getNameAsString();
     const uint32_t tempVarId = theBuilder.addFnVar(varType, varName);
 
@@ -1575,7 +1619,7 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
     workQueue.insert(callee);
   }
 
-  const uint32_t retType = typeTranslator.translateType(callExpr->getType());
+  const uint32_t retType = declIdMapper.getTypeForPotentialAliasVar(callee);
   // Get or forward declare the function <result-id>
   const uint32_t funcId = declIdMapper.getOrRegisterFnResultId(callee);
 
@@ -1995,7 +2039,7 @@ SPIRVEmitter::doConditionalOperator(const ConditionalOperator *expr) {
 uint32_t SPIRVEmitter::processByteAddressBufferStructuredBufferGetDimensions(
     const CXXMemberCallExpr *expr) {
   const auto *object = expr->getImplicitObjectArgument();
-  const auto objectId = doExpr(object);
+  const auto objectId = loadIfAliasVarRef(object);
   const auto type = object->getType();
   const bool isByteAddressBuffer = TypeTranslator::isByteAddressBuffer(type) ||
                                    TypeTranslator::isRWByteAddressBuffer(type);
@@ -2040,8 +2084,7 @@ uint32_t SPIRVEmitter::processRWByteAddressBufferAtomicMethods(
   // void Interlocked*(in UINT dest, in UINT value, out UINT original_value);
 
   const auto *object = expr->getImplicitObjectArgument();
-  // We do not need to load the object since we are using its pointers.
-  const auto objectInfo = doExpr(object);
+  const auto objectInfo = loadIfAliasVarRef(object);
 
   const auto uintType = theBuilder.getUint32Type();
   const uint32_t zero = theBuilder.getConstantUint32(0);
@@ -2400,7 +2443,7 @@ SpirvEvalInfo SPIRVEmitter::processByteAddressBufferLoadStore(
   uint32_t resultId = 0;
   const auto object = expr->getImplicitObjectArgument();
   const auto type = object->getType();
-  const auto objectInfo = doExpr(object);
+  const auto objectInfo = loadIfAliasVarRef(object);
   assert(numWords >= 1 && numWords <= 4);
   if (doStore) {
     assert(typeTranslator.isRWByteAddressBuffer(type));
@@ -2493,7 +2536,7 @@ SPIRVEmitter::processStructuredBufferLoad(const CXXMemberCallExpr *expr) {
   }
 
   const auto *buffer = expr->getImplicitObjectArgument();
-  auto info = doExpr(buffer);
+  auto info = loadIfAliasVarRef(buffer);
 
   const QualType structType =
       hlsl::GetHLSLResourceResultType(buffer->getType());
@@ -3214,9 +3257,11 @@ SPIRVEmitter::doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr) {
   llvm::SmallVector<uint32_t, 4> indices;
   const Expr *baseExpr = collectArrayStructIndices(expr, &indices);
 
-  auto base = doExpr(baseExpr);
+  auto base = loadIfAliasVarRef(baseExpr);
+
   if (indices.empty())
     return base; // For indexing into size-1 vectors and 1xN matrices
+
   // If we are indexing into a rvalue, to use OpAccessChain, we first need
   // to create a local variable to hold the rvalue.
   //
@@ -3380,7 +3425,7 @@ SpirvEvalInfo SPIRVEmitter::doInitListExpr(const InitListExpr *expr) {
 SpirvEvalInfo SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {
   llvm::SmallVector<uint32_t, 4> indices;
   const Expr *base = collectArrayStructIndices(expr, &indices);
-  auto info = doExpr(base);
+  auto info = loadIfAliasVarRef(base);
 
   if (!indices.empty()) {
     // Sometime we are accessing the member of a rvalue, e.g.,

+ 14 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -106,6 +106,13 @@ private:
   SpirvEvalInfo doMemberExpr(const MemberExpr *expr);
   SpirvEvalInfo doUnaryOperator(const UnaryOperator *expr);
 
+  /// Loads the pointer of the aliased-to-variable if the given expression is a
+  /// DeclRefExpr referencing an alias variable. See DeclResultIdMapper for
+  /// more explanation regarding this.
+  ///
+  /// Note: legalization specific code
+  SpirvEvalInfo loadIfAliasVarRef(const Expr *expr);
+
 private:
   /// Translates the given frontend binary operator into its SPIR-V equivalent
   /// taking consideration of the operand type.
@@ -787,10 +794,16 @@ private:
   /// Whether the translated SPIR-V binary needs legalization.
   ///
   /// The following cases will require legalization:
-  /// * Opaque types (textures, samplers) within structs
+  ///
+  /// 1. Opaque types (textures, samplers) within structs
+  /// 2. Structured buffer assignments
+  ///
+  /// This covers the first case.
   ///
   /// If this is true, SPIRV-Tools legalization passes will be executed after
   /// the translation to legalize the generated SPIR-V binary.
+  ///
+  /// Note: legalization specific code
   bool needsLegalization;
 
   /// Global variables that should be initialized once at the begining of the

+ 17 - 1
tools/clang/lib/SPIRV/SpirvEvalInfo.h

@@ -79,6 +79,9 @@ public:
   /// Handly implicit conversion to test whether the <result-id> is valid.
   operator bool() const { return resultId != 0; }
 
+  inline SpirvEvalInfo &setValTypeId(uint32_t id);
+  uint32_t getValTypeId() const { return valTypeId; }
+
   inline SpirvEvalInfo &setStorageClass(spv::StorageClass sc);
   spv::StorageClass getStorageClass() const { return storageClass; }
 
@@ -96,6 +99,14 @@ public:
 
 private:
   uint32_t resultId;
+  /// The value's <type-id> for this variable.
+  ///
+  /// This field should only be non-zero for original alias variables, which is
+  /// of pointer-to-pointer type. After dereferencing the alias variable, this
+  /// should be set to zero to let CodeGen fall back to normal handling path.
+  ///
+  /// Note: legalization specific code
+  uint32_t valTypeId;
 
   spv::StorageClass storageClass;
   LayoutRule layoutRule;
@@ -106,7 +117,7 @@ private:
 };
 
 SpirvEvalInfo::SpirvEvalInfo(uint32_t id)
-    : resultId(id), storageClass(spv::StorageClass::Function),
+    : resultId(id), valTypeId(0), storageClass(spv::StorageClass::Function),
       layoutRule(LayoutRule::Void), isRValue_(false), isConstant_(false),
       isRelaxedPrecision_(false) {}
 
@@ -121,6 +132,11 @@ SpirvEvalInfo SpirvEvalInfo::substResultId(uint32_t newId) const {
   return info;
 }
 
+SpirvEvalInfo &SpirvEvalInfo::setValTypeId(uint32_t id) {
+  valTypeId = id;
+  return *this;
+}
+
 SpirvEvalInfo &SpirvEvalInfo::setStorageClass(spv::StorageClass sc) {
   storageClass = sc;
   return *this;

+ 11 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -377,6 +377,17 @@ bool TypeTranslator::isConsumeStructuredBuffer(QualType type) {
   return name == "ConsumeStructuredBuffer";
 }
 
+bool TypeTranslator::isAKindOfStructuredOrByteBuffer(QualType type) {
+  if (const RecordType *recordType = type->getAs<RecordType>()) {
+    StringRef name = recordType->getDecl()->getName();
+    return name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
+           name == "ByteAddressBuffer" || name == "RWByteAddressBuffer" ||
+           name == "AppendStructuredBuffer" ||
+           name == "ConsumeStructuredBuffer";
+  }
+  return false;
+}
+
 bool TypeTranslator::isStructuredBuffer(QualType type) {
   const auto *recordType = type->getAs<RecordType>();
   if (!recordType)

+ 9 - 1
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -80,6 +80,10 @@ public:
   /// \brief Returns true if the given type is the HLSL RWByteAddressBufferType.
   static bool isRWByteAddressBuffer(QualType type);
 
+  /// \brief Returns true if the given type is the HLSL (RW)StructuredBuffer,
+  /// (RW)ByteAddressBuffer, or {Append|Consume}StructuredBuffer.
+  static bool isAKindOfStructuredOrByteBuffer(QualType type);
+
   /// \brief Returns true if the given type is the HLSL Buffer type.
   static bool isBuffer(QualType type);
 
@@ -150,10 +154,14 @@ public:
 
   /// Returns true if the given type will be translated into a SPIR-V image,
   /// sampler or struct containing images or samplers.
+  ///
+  /// Note: legalization specific code
   static bool isOpaqueType(QualType type);
 
   /// Returns true if the given type is a struct type who has an opaque field
   /// (in a recursive away).
+  ///
+  /// Note: legalization specific code
   static bool isOpaqueStructType(QualType tye);
 
   /// \brief Returns a string name for the given type.
@@ -227,4 +235,4 @@ private:
 } // end namespace spirv
 } // end namespace clang
 
-#endif
+#endif

+ 182 - 0
tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.methods.hlsl

@@ -0,0 +1,182 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct S {
+    float4 f;
+};
+
+struct T1 {
+    float3 f;
+};
+
+struct T2 {
+    float2 f;
+};
+
+StructuredBuffer<S>         globalSBuffer;
+RWStructuredBuffer<S>       globalRWSBuffer;
+AppendStructuredBuffer<T1>  globalASBuffer;
+ConsumeStructuredBuffer<T2> globalCSBuffer;
+ByteAddressBuffer           globalBABuffer;
+RWByteAddressBuffer         globalRWBABuffer;
+
+float4 main() : SV_Target {
+// CHECK: %localSBuffer = OpVariable %_ptr_Function__ptr_Uniform_type_StructuredBuffer_S Function
+// CHECK: %localRWSBuffer = OpVariable %_ptr_Function__ptr_Uniform_type_RWStructuredBuffer_S Function
+// CHECK: %localASBuffer = OpVariable %_ptr_Function__ptr_Uniform_type_AppendStructuredBuffer_T1 Function
+// CHECK: %localCSBuffer = OpVariable %_ptr_Function__ptr_Uniform_type_ConsumeStructuredBuffer_T2 Function
+// CHECK: %localBABuffer = OpVariable %_ptr_Function__ptr_Uniform_type_ByteAddressBuffer Function
+// CHECK: %localRWBABuffer = OpVariable %_ptr_Function__ptr_Uniform_type_RWByteAddressBuffer Function
+
+// CHECK: OpStore %localSBuffer %globalSBuffer
+// CHECK: OpStore %localRWSBuffer %globalRWSBuffer
+// CHECK: OpStore %localASBuffer %globalASBuffer
+// CHECK: OpStore %localCSBuffer %globalCSBuffer
+// CHECK: OpStore %localBABuffer %globalBABuffer
+// CHECK: OpStore %localRWBABuffer %globalRWBABuffer
+    StructuredBuffer<S>         localSBuffer    = globalSBuffer;
+    RWStructuredBuffer<S>       localRWSBuffer  = globalRWSBuffer;
+    AppendStructuredBuffer<T1>  localASBuffer   = globalASBuffer;
+    ConsumeStructuredBuffer<T2> localCSBuffer   = globalCSBuffer;
+    ByteAddressBuffer           localBABuffer   = globalBABuffer;
+    RWByteAddressBuffer         localRWBABuffer = globalRWBABuffer;
+
+    T1 t1 = {float3(1., 2., 3.)};
+    T2 t2;
+    uint numStructs, stride, counter;
+    float4 val;
+
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_StructuredBuffer_S %localSBuffer
+// CHECK-NEXT:     {{%\d+}} = OpArrayLength %uint [[ptr]] 0
+    localSBuffer.GetDimensions(numStructs, stride);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_StructuredBuffer_S %localSBuffer
+// CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_v4float [[ptr1]] %int_0 %int_1 %int_0
+// CHECK-NEXT:      {{%\d+}} = OpLoad %v4float [[ptr2]]
+    val = localSBuffer.Load(1).f;
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_StructuredBuffer_S %localSBuffer
+// CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_v4float [[ptr1]] %int_0 %uint_2 %int_0
+// CHECK-NEXT:      {{%\d+}} = OpLoad %v4float [[ptr2]]
+    val = localSBuffer[2].f;
+
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_RWStructuredBuffer_S %localRWSBuffer
+// CHECK-NEXT:     {{%\d+}} = OpArrayLength %uint [[ptr]] 0
+    localRWSBuffer.GetDimensions(numStructs, stride);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWStructuredBuffer_S %localRWSBuffer
+// CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_v4float [[ptr1]] %int_0 %int_3 %int_0
+// CHECK-NEXT:      {{%\d+}} = OpLoad %v4float [[ptr2]]
+    val = localRWSBuffer.Load(3).f;
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWStructuredBuffer_S %localRWSBuffer
+// CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_v4float [[ptr1]] %int_0 %uint_4 %int_0
+// CHECK-NEXT:                 OpStore [[ptr2]] {{%\d+}}
+    localRWSBuffer[4].f = 42.;
+    // TODO
+    counter = localRWSBuffer.IncrementCounter();
+    // TODO
+    counter = localRWSBuffer.DecrementCounter();
+
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_AppendStructuredBuffer_T1 %localASBuffer
+// CHECK-NEXT:     {{%\d+}} = OpArrayLength %uint [[ptr]] 0
+    localASBuffer.GetDimensions(numStructs, stride);
+    // TODO
+    counter = localRWSBuffer.DecrementCounter();
+    localASBuffer.Append(t1);
+
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ConsumeStructuredBuffer_T2 %localCSBuffer
+// CHECK-NEXT:     {{%\d+}} = OpArrayLength %uint [[ptr]] 0
+    localCSBuffer.GetDimensions(numStructs, stride);
+    // TODO
+    counter = localRWSBuffer.DecrementCounter();
+    t2 = localCSBuffer.Consume();
+
+    uint  byte;
+    uint2 byte2;
+    uint3 byte3;
+    uint4 byte4;
+    uint  dim;
+
+    uint dest, value, compare, origin;
+
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_ByteAddressBuffer %localBABuffer
+// CHECK-NEXT:     {{%\d+}} = OpArrayLength %uint [[ptr]] 0
+    localBABuffer.GetDimensions(dim);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_ByteAddressBuffer %localBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+// CHECK-NEXT:      {{%\d+}} = OpLoad %uint [[ptr2]]
+    byte  = localBABuffer.Load(4);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_ByteAddressBuffer %localBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+// CHECK-NEXT:      {{%\d+}} = OpLoad %uint [[ptr2]]
+    byte2 = localBABuffer.Load2(5);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_ByteAddressBuffer %localBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+// CHECK-NEXT:      {{%\d+}} = OpLoad %uint [[ptr2]]
+    byte3 = localBABuffer.Load3(6);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_ByteAddressBuffer %localBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+// CHECK-NEXT:      {{%\d+}} = OpLoad %uint [[ptr2]]
+    byte4 = localBABuffer.Load4(7);
+
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK-NEXT:     {{%\d+}} = OpArrayLength %uint [[ptr]] 0
+    localRWBABuffer.GetDimensions(dim);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+// CHECK-NEXT:      {{%\d+}} = OpLoad %uint [[ptr2]]
+    byte  = localRWBABuffer.Load(8);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+// CHECK-NEXT:      {{%\d+}} = OpLoad %uint [[ptr2]]
+    byte2 = localRWBABuffer.Load2(9);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+// CHECK-NEXT:      {{%\d+}} = OpLoad %uint [[ptr2]]
+    byte3 = localRWBABuffer.Load3(10);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+// CHECK-NEXT:      {{%\d+}} = OpLoad %uint [[ptr2]]
+    byte4 = localRWBABuffer.Load4(11);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+// CHECK-NEXT:                 OpStore [[ptr2]] {{%\d+}}
+    localRWBABuffer.Store(12, byte);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+// CHECK-NEXT:                 OpStore [[ptr2]] {{%\d+}}
+    localRWBABuffer.Store(13, byte2);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+// CHECK-NEXT:                 OpStore [[ptr2]] {{%\d+}}
+    localRWBABuffer.Store(14, byte3);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+// CHECK-NEXT:                 OpStore [[ptr2]] {{%\d+}}
+    localRWBABuffer.Store(15, byte4);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+    localRWBABuffer.InterlockedAdd(dest, value, origin);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+    localRWBABuffer.InterlockedAnd(dest, value, origin);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+    localRWBABuffer.InterlockedCompareExchange(dest, compare, value, origin);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+    localRWBABuffer.InterlockedCompareStore(dest, compare, value);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+    localRWBABuffer.InterlockedExchange(dest, value, origin);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+    localRWBABuffer.InterlockedMax(dest, value, origin);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+    localRWBABuffer.InterlockedMin(dest, value, origin);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+    localRWBABuffer.InterlockedOr(dest, value, origin);
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWByteAddressBuffer %localRWBABuffer
+// CHECK:      [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ptr1]] %uint_0 {{%\d+}}
+    localRWBABuffer.InterlockedXor(dest, value, origin);
+
+    return val;
+}

+ 86 - 0
tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.usage.hlsl

@@ -0,0 +1,86 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct S {
+    float4 f;
+};
+
+// Signature for returnRWSBuffer(). Both the function parameter and return gain
+// an extra level of pointer.
+// CHECK: [[retRWSBSig:%\d+]] = OpTypeFunction %_ptr_Uniform_type_RWStructuredBuffer_S %_ptr_Function__ptr_Uniform_type_RWStructuredBuffer_S
+RWStructuredBuffer<S> returnRWSBuffer(RWStructuredBuffer<S> paramRWSBuffer);
+
+float4 useAsStaticRWSBuffer();
+
+// CHECK:  %globalRWSBuffer = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_S Uniform
+       RWStructuredBuffer<S> globalRWSBuffer;
+// Static global variables gain an extra level of pointer.
+// CHECK: %staticgRWSBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_RWStructuredBuffer_S Private
+static RWStructuredBuffer<S> staticgRWSBuffer = globalRWSBuffer;
+
+// CHECK: %globalv4f32RWSBuffer = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_v4float Uniform
+  RWStructuredBuffer<float4> globalv4f32RWSBuffer;
+
+// Static local variables gain an extra level of pointer.
+// CHECK: %staticRWSBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_RWStructuredBuffer_S Private
+
+// CHECK: %src_main = OpFunction
+float4 main(in float4 pos : SV_Position) : SV_Target
+{
+// Local variables gain an extra level of pointer.
+// CHECK:           %localRWSBuffer = OpVariable %_ptr_Function__ptr_Uniform_type_RWStructuredBuffer_S Function
+
+// Temporary variables for function calls gain an extra level of pointer.
+// CHECK: %param_var_paramRWSBuffer = OpVariable %_ptr_Function__ptr_Uniform_type_RWStructuredBuffer_S Function
+
+// CHECK: %localv4f32RWSBuffer = OpVariable %_ptr_Function__ptr_Uniform_type_RWStructuredBuffer_v4float Function
+
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWStructuredBuffer_S %staticgRWSBuffer
+// CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_v4float [[ptr1]] %int_0 %uint_0 %int_0
+// CHECK-NEXT:      {{%\d+}} = OpLoad %v4float [[ptr2]]
+    float4 val = staticgRWSBuffer[0].f + useAsStaticRWSBuffer();
+
+// Directly storing the pointer to %param_var_paramRWSBuffer
+// CHECK:      OpStore %param_var_paramRWSBuffer %globalRWSBuffer
+// Function calls have matching signatures.
+// CHECK-NEXT: [[ptr:%\d+]] = OpFunctionCall %_ptr_Uniform_type_RWStructuredBuffer_S %returnRWSBuffer %param_var_paramRWSBuffer
+// CHECK-NEXT: OpStore %localRWSBuffer [[ptr]]
+    RWStructuredBuffer<S> localRWSBuffer = returnRWSBuffer(globalRWSBuffer);
+
+// CHECK:      [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_RWStructuredBuffer_S %staticgRWSBuffer
+// CHECK-NEXT: OpStore %localRWSBuffer [[ptr]]
+    localRWSBuffer = staticgRWSBuffer;
+
+// CHECK:      {{%\d+}} = OpAccessChain %_ptr_Uniform_v4float %globalRWSBuffer %int_0 %uint_1 %int_0
+    globalRWSBuffer[1].f = 4.2;
+
+// CHECK:      OpStore %localv4f32RWSBuffer %globalv4f32RWSBuffer
+    RWStructuredBuffer<float4> localv4f32RWSBuffer = globalv4f32RWSBuffer;
+
+    return val +
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWStructuredBuffer_S %localRWSBuffer
+// CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_v4float [[ptr1]] %int_0 %uint_2 %int_0
+// CHECK-NEXT:      {{%\d+}} = OpLoad %v4float [[ptr2]]
+        localRWSBuffer[2].f +
+// CHECK:      [[ptr1:%\d+]] = OpLoad %_ptr_Uniform_type_RWStructuredBuffer_v4float %localv4f32RWSBuffer
+// CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Uniform_v4float [[ptr1]] %int_0 %uint_3
+// CHECK-NEXT:      {{%\d+}} = OpLoad %v4float [[ptr2]]
+        localv4f32RWSBuffer[3];
+}
+
+// CHECK: %useAsStaticRWSBuffer = OpFunction
+float4 useAsStaticRWSBuffer() {
+// Directly storing the pointer to %staticRWSBuffer
+// CHECK: OpStore %staticRWSBuffer %globalRWSBuffer
+    static RWStructuredBuffer<S> staticRWSBuffer = globalRWSBuffer;
+    staticRWSBuffer[0].f = 30;
+    return staticRWSBuffer[0].f;
+}
+
+// CHECK: %returnRWSBuffer = OpFunction %_ptr_Uniform_type_RWStructuredBuffer_S None [[retRWSBSig]]
+// Function parameters gain an extra level of pointer.
+// CHECK:  %paramRWSBuffer = OpFunctionParameter %_ptr_Function__ptr_Uniform_type_RWStructuredBuffer_S
+RWStructuredBuffer<S> returnRWSBuffer(RWStructuredBuffer<S> paramRWSBuffer) {
+// CHECK:     [[ptr:%\d+]] = OpLoad %_ptr_Uniform_type_RWStructuredBuffer_S %paramRWSBuffer
+// CHECK-NEXT:               OpReturnValue [[ptr]]
+    return paramRWSBuffer;
+}

+ 10 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -968,6 +968,16 @@ TEST_F(FileTest, SpirvLegalizationOpaqueStruct) {
               // The generated SPIR-V needs legalization.
               /*runValidation=*/false);
 }
+TEST_F(FileTest, SpirvLegalizationStructuredBufferUsage) {
+  runFileTest("spirv.legal.sbuffer.usage.hlsl", Expect::Success,
+              // The generated SPIR-V needs legalization.
+              /*runValidation=*/false);
+}
+TEST_F(FileTest, SpirvLegalizationStructuredBufferMethods) {
+  runFileTest("spirv.legal.sbuffer.methods.hlsl", Expect::Success,
+              // The generated SPIR-V needs legalization.
+              /*runValidation=*/false);
+}
 TEST_F(FileTest, SpirvLegalizationConstantBuffer) {
   runFileTest("spirv.legal.cbuffer.hlsl", Expect::Success,
               // The generated SPIR-V needs legalization.