2
0
Эх сурвалжийг харах

[spirv] Support (RW)ByteAddrressBuffer and load/store methods (#581)

Ehsan 8 жил өмнө
parent
commit
4cbc45cf27

+ 21 - 6
tools/clang/include/clang/SPIRV/Constant.h

@@ -17,6 +17,7 @@
 #include "clang/SPIRV/Decoration.h"
 #include "clang/SPIRV/Decoration.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SetVector.h"
 
 
 namespace clang {
 namespace clang {
 namespace spirv {
 namespace spirv {
@@ -37,12 +38,12 @@ class SPIRVContext;
 /// context).
 /// context).
 class Constant {
 class Constant {
 public:
 public:
-  using DecorationSet = std::set<const Decoration *>;
+  using DecorationSet = llvm::ArrayRef<const Decoration *>;
 
 
   spv::Op getOpcode() const { return opcode; }
   spv::Op getOpcode() const { return opcode; }
   uint32_t getTypeId() const { return typeId; }
   uint32_t getTypeId() const { return typeId; }
   const std::vector<uint32_t> &getArgs() const { return args; }
   const std::vector<uint32_t> &getArgs() const { return args; }
-  const DecorationSet &getDecorations() const { return decorations; }
+  const auto &getDecorations() const { return decorations; }
   bool hasDecoration(const Decoration *) const;
   bool hasDecoration(const Decoration *) const;
 
 
   // OpConstantTrue and OpConstantFalse are boolean.
   // OpConstantTrue and OpConstantFalse are boolean.
@@ -95,8 +96,17 @@ public:
                                           DecorationSet dec = {});
                                           DecorationSet dec = {});
 
 
   bool operator==(const Constant &other) const {
   bool operator==(const Constant &other) const {
-    return opcode == other.opcode && typeId == other.typeId &&
-           args == other.args && decorations == other.decorations;
+    if (opcode == other.opcode && typeId == other.typeId &&
+        args == other.args && decorations.size() == other.decorations.size()) {
+      // If two constants have the same decorations, but in different order,
+      // they are in fact the same.
+      for (const Decoration *dec : decorations) {
+        if (other.decorations.count(dec) == 0)
+          return false;
+      }
+      return true;
+    }
+    return false;
   }
   }
 
 
   // \brief Construct the SPIR-V words for this constant with the given
   // \brief Construct the SPIR-V words for this constant with the given
@@ -106,7 +116,7 @@ public:
 private:
 private:
   /// \brief Private constructor.
   /// \brief Private constructor.
   Constant(spv::Op, uint32_t type, llvm::ArrayRef<uint32_t> arg = {},
   Constant(spv::Op, uint32_t type, llvm::ArrayRef<uint32_t> arg = {},
-           std::set<const Decoration *> dec = {});
+           DecorationSet dec = {});
 
 
   /// \brief Returns the unique constant pointer within the given context.
   /// \brief Returns the unique constant pointer within the given context.
   static const Constant *getUniqueConstant(SPIRVContext &, const Constant &);
   static const Constant *getUniqueConstant(SPIRVContext &, const Constant &);
@@ -115,7 +125,12 @@ private:
   spv::Op opcode;             ///< OpCode of the constant
   spv::Op opcode;             ///< OpCode of the constant
   uint32_t typeId;            ///< <result-id> of the type of the constant
   uint32_t typeId;            ///< <result-id> of the type of the constant
   std::vector<uint32_t> args; ///< Arguments needed to define the constant
   std::vector<uint32_t> args; ///< Arguments needed to define the constant
-  DecorationSet decorations;  ///< Decorations applied to the constant
+
+  /// The decorations that are applied to a constant.
+  /// Note: we use a SetVector because:
+  /// a) Duplicate decorations should be removed.
+  /// b) Order of insertion matters for deterministic SPIR-V emitting
+  llvm::SetVector<const Decoration *> decorations;
 };
 };
 
 
 } // end namespace spirv
 } // end namespace spirv

+ 1 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -292,6 +292,7 @@ public:
   uint32_t getImageType(uint32_t sampledType, spv::Dim, bool isArray);
   uint32_t getImageType(uint32_t sampledType, spv::Dim, bool isArray);
   uint32_t getSamplerType();
   uint32_t getSamplerType();
   uint32_t getSampledImageType(uint32_t imageType);
   uint32_t getSampledImageType(uint32_t imageType);
+  uint32_t getByteAddressBufferType(bool isRW);
 
 
   // === Constant ===
   // === Constant ===
   uint32_t getConstantBool(bool value);
   uint32_t getConstantBool(bool value);

+ 16 - 4
tools/clang/include/clang/SPIRV/Structure.h

@@ -303,8 +303,13 @@ public:
   // TODO: source code debug information
   // TODO: source code debug information
   inline void addDebugName(uint32_t targetId, llvm::StringRef name,
   inline void addDebugName(uint32_t targetId, llvm::StringRef name,
                            llvm::Optional<uint32_t> memberIndex = llvm::None);
                            llvm::Optional<uint32_t> memberIndex = llvm::None);
-  inline void addDecoration(const Decoration &decoration, uint32_t targetId);
+  /// \brief Adds a decoration to the given target.
+  inline void addDecoration(const Decoration *decoration, uint32_t targetId);
+  /// \brief Adds a type to the module. Also adds the type's decorations to the
+  /// set of decorations of the module.
   inline void addType(const Type *type, uint32_t resultId);
   inline void addType(const Type *type, uint32_t resultId);
+  /// \brief Adds a constant to the module. Also adds the constant's decorations
+  /// to the set of decorations of the module.
   inline void addConstant(const Constant *constant, uint32_t resultId);
   inline void addConstant(const Constant *constant, uint32_t resultId);
   inline void addVariable(Instruction &&);
   inline void addVariable(Instruction &&);
   inline void addFunction(std::unique_ptr<Function>);
   inline void addFunction(std::unique_ptr<Function>);
@@ -340,7 +345,8 @@ private:
   std::vector<Instruction> executionModes;
   std::vector<Instruction> executionModes;
   // TODO: source code debug information
   // TODO: source code debug information
   std::vector<DebugName> debugNames;
   std::vector<DebugName> debugNames;
-  std::vector<DecorationIdPair> decorations;
+  llvm::SetVector<std::pair<uint32_t, const Decoration *>> decorations;
+
   // Note that types and constants are interdependent; Types like arrays have
   // Note that types and constants are interdependent; Types like arrays have
   // <result-id>s for constants in their definition, and constants all have
   // <result-id>s for constants in their definition, and constants all have
   // their corresponding types. We store types and constants separately, but
   // their corresponding types. We store types and constants separately, but
@@ -487,17 +493,23 @@ void SPIRVModule::addDebugName(uint32_t targetId, llvm::StringRef name,
   }
   }
 }
 }
 
 
-void SPIRVModule::addDecoration(const Decoration &decoration,
+void SPIRVModule::addDecoration(const Decoration *decoration,
                                 uint32_t targetId) {
                                 uint32_t targetId) {
-  decorations.emplace_back(decoration, targetId);
+  decorations.insert(std::make_pair(targetId, decoration));
 }
 }
 
 
 void SPIRVModule::addType(const Type *type, uint32_t resultId) {
 void SPIRVModule::addType(const Type *type, uint32_t resultId) {
   types.insert(std::make_pair(type, resultId));
   types.insert(std::make_pair(type, resultId));
+  for (const Decoration *d : type->getDecorations()) {
+    addDecoration(d, resultId);
+  }
 }
 }
 
 
 void SPIRVModule::addConstant(const Constant *constant, uint32_t resultId) {
 void SPIRVModule::addConstant(const Constant *constant, uint32_t resultId) {
   constants.insert(std::make_pair(constant, resultId));
   constants.insert(std::make_pair(constant, resultId));
+  for (const Decoration *d : constant->getDecorations()) {
+    addDecoration(d, resultId);
+  }
 };
 };
 
 
 void SPIRVModule::addVariable(Instruction &&var) {
 void SPIRVModule::addVariable(Instruction &&var) {

+ 21 - 7
tools/clang/include/clang/SPIRV/Type.h

@@ -17,6 +17,7 @@
 #include "clang/SPIRV/Decoration.h"
 #include "clang/SPIRV/Decoration.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SetVector.h"
 
 
 namespace clang {
 namespace clang {
 namespace spirv {
 namespace spirv {
@@ -36,11 +37,11 @@ class SPIRVContext;
 /// context).
 /// context).
 class Type {
 class Type {
 public:
 public:
-  using DecorationSet = std::set<const Decoration *>;
+  using DecorationSet = llvm::ArrayRef<const Decoration *>;
 
 
   spv::Op getOpcode() const { return opcode; }
   spv::Op getOpcode() const { return opcode; }
   const std::vector<uint32_t> &getArgs() const { return args; }
   const std::vector<uint32_t> &getArgs() const { return args; }
-  const DecorationSet &getDecorations() const { return decorations; }
+  const auto &getDecorations() const { return decorations; }
   bool hasDecoration(const Decoration *) const;
   bool hasDecoration(const Decoration *) const;
 
 
   bool isBooleanType() const;
   bool isBooleanType() const;
@@ -111,8 +112,17 @@ public:
                                        spv::StorageClass storage_class,
                                        spv::StorageClass storage_class,
                                        DecorationSet decs = {});
                                        DecorationSet decs = {});
   bool operator==(const Type &other) const {
   bool operator==(const Type &other) const {
-    return opcode == other.opcode && args == other.args &&
-           decorations == other.decorations;
+    if (opcode == other.opcode && args == other.args &&
+        decorations.size() == other.decorations.size()) {
+      // If two types have the same decorations, but in different order,
+      // they are in fact the same type.
+      for (const Decoration* dec : decorations) {
+        if (other.decorations.count(dec) == 0)
+          return false;
+      }
+      return true;
+    }
+    return false;
   }
   }
 
 
   // \brief Construct the SPIR-V words for this type with the given <result-id>.
   // \brief Construct the SPIR-V words for this type with the given <result-id>.
@@ -120,8 +130,7 @@ public:
 
 
 private:
 private:
   /// \brief Private constructor.
   /// \brief Private constructor.
-  Type(spv::Op op, std::vector<uint32_t> arg = {},
-       std::set<const Decoration *> dec = {});
+  Type(spv::Op op, std::vector<uint32_t> arg = {}, DecorationSet dec = {});
 
 
   /// \brief Returns the unique Type pointer within the given context.
   /// \brief Returns the unique Type pointer within the given context.
   static const Type *getUniqueType(SPIRVContext &, const Type &);
   static const Type *getUniqueType(SPIRVContext &, const Type &);
@@ -129,7 +138,12 @@ private:
 private:
 private:
   spv::Op opcode;             ///< OpCode of the Type defined in SPIR-V Spec
   spv::Op opcode;             ///< OpCode of the Type defined in SPIR-V Spec
   std::vector<uint32_t> args; ///< Arguments needed to define the type
   std::vector<uint32_t> args; ///< Arguments needed to define the type
-  DecorationSet decorations;  ///< decorations applied to the type
+
+  /// The decorations that are applied to a type.
+  /// Note: we use a SetVector because:
+  /// a) Duplicate decorations should be removed.
+  /// b) Order of insertion matters for deterministic SPIR-V emitting
+  llvm::SetVector<const Decoration *> decorations;
 };
 };
 
 
 } // end namespace spirv
 } // end namespace spirv

+ 5 - 3
tools/clang/lib/SPIRV/Constant.cpp

@@ -15,8 +15,10 @@ namespace clang {
 namespace spirv {
 namespace spirv {
 
 
 Constant::Constant(spv::Op op, uint32_t type, llvm::ArrayRef<uint32_t> arg,
 Constant::Constant(spv::Op op, uint32_t type, llvm::ArrayRef<uint32_t> arg,
-                   std::set<const Decoration *> decs)
-    : opcode(op), typeId(type), args(arg), decorations(decs) {}
+                   DecorationSet decs)
+    : opcode(op), typeId(type), args(arg) {
+  decorations = llvm::SetVector<const Decoration *>(decs.begin(), decs.end());
+}
 
 
 const Constant *Constant::getUniqueConstant(SPIRVContext &context,
 const Constant *Constant::getUniqueConstant(SPIRVContext &context,
                                             const Constant &c) {
                                             const Constant &c) {
@@ -121,7 +123,7 @@ Constant::getSpecComposite(SPIRVContext &ctx, uint32_t type_id,
 }
 }
 
 
 bool Constant::hasDecoration(const Decoration *d) const {
 bool Constant::hasDecoration(const Decoration *d) const {
-  return decorations.find(d) != decorations.end();
+  return decorations.count(d);
 }
 }
 
 
 bool Constant::isBoolean() const {
 bool Constant::isBoolean() const {

+ 12 - 4
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -100,10 +100,18 @@ uint32_t DeclResultIdMapper::createFileVar(uint32_t varType, const VarDecl *var,
 
 
 uint32_t DeclResultIdMapper::createExternVar(uint32_t varType,
 uint32_t DeclResultIdMapper::createExternVar(uint32_t varType,
                                              const VarDecl *var) {
                                              const VarDecl *var) {
-  // TODO: storage class can also be Uniform
-  const uint32_t id = theBuilder.addModuleVar(
-      varType, spv::StorageClass::UniformConstant, var->getName(), llvm::None);
-  astDecls[var] = {id, spv::StorageClass::UniformConstant};
+  auto storageClass = spv::StorageClass::UniformConstant;
+
+  // TODO: Figure out other cases where the storage class should be Uniform.
+  if (auto *t = var->getType()->getAs<RecordType>()) {
+    const llvm::StringRef typeName = t->getDecl()->getName();
+    if (typeName == "ByteAddressBuffer" || typeName == "RWByteAddressBuffer")
+      storageClass = spv::StorageClass::Uniform;
+  }
+
+  const uint32_t id = theBuilder.addModuleVar(varType, storageClass,
+                                              var->getName(), llvm::None);
+  astDecls[var] = {id, storageClass};
   resourceVars.emplace_back(id, getResourceBinding(var),
   resourceVars.emplace_back(id, getResourceBinding(var),
                             var->getAttr<VKBindingAttr>());
                             var->getAttr<VKBindingAttr>());
 
 

+ 33 - 5
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -466,7 +466,7 @@ uint32_t ModuleBuilder::addStageBuiltinVar(uint32_t type, spv::StorageClass sc,
 
 
   // Decorate with the specified Builtin
   // Decorate with the specified Builtin
   const Decoration *d = Decoration::getBuiltIn(theContext, builtin);
   const Decoration *d = Decoration::getBuiltIn(theContext, builtin);
-  theModule.addDecoration(*d, varId);
+  theModule.addDecoration(d, varId);
 
 
   return varId;
   return varId;
 }
 }
@@ -488,16 +488,16 @@ uint32_t ModuleBuilder::addModuleVar(uint32_t type, spv::StorageClass sc,
 void ModuleBuilder::decorateDSetBinding(uint32_t targetId, uint32_t setNumber,
 void ModuleBuilder::decorateDSetBinding(uint32_t targetId, uint32_t setNumber,
                                         uint32_t bindingNumber) {
                                         uint32_t bindingNumber) {
   const auto *d = Decoration::getDescriptorSet(theContext, setNumber);
   const auto *d = Decoration::getDescriptorSet(theContext, setNumber);
-  theModule.addDecoration(*d, targetId);
+  theModule.addDecoration(d, targetId);
 
 
   d = Decoration::getBinding(theContext, bindingNumber);
   d = Decoration::getBinding(theContext, bindingNumber);
-  theModule.addDecoration(*d, targetId);
+  theModule.addDecoration(d, targetId);
 }
 }
 
 
 void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
 void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
   const Decoration *d =
   const Decoration *d =
       Decoration::getLocation(theContext, location, llvm::None);
       Decoration::getLocation(theContext, location, llvm::None);
-  theModule.addDecoration(*d, targetId);
+  theModule.addDecoration(d, targetId);
 }
 }
 
 
 void ModuleBuilder::decorate(uint32_t targetId, spv::Decoration decoration) {
 void ModuleBuilder::decorate(uint32_t targetId, spv::Decoration decoration) {
@@ -518,7 +518,7 @@ void ModuleBuilder::decorate(uint32_t targetId, spv::Decoration decoration) {
   }
   }
 
 
   assert(d && "unimplemented decoration");
   assert(d && "unimplemented decoration");
-  theModule.addDecoration(*d, targetId);
+  theModule.addDecoration(d, targetId);
 }
 }
 
 
 #define IMPL_GET_PRIMITIVE_TYPE(ty)                                            \
 #define IMPL_GET_PRIMITIVE_TYPE(ty)                                            \
@@ -665,6 +665,7 @@ uint32_t ModuleBuilder::getSamplerType() {
   return typeId;
   return typeId;
 }
 }
 
 
+
 uint32_t ModuleBuilder::getSampledImageType(uint32_t imageType) {
 uint32_t ModuleBuilder::getSampledImageType(uint32_t imageType) {
   const Type *type = Type::getSampledImage(theContext, imageType);
   const Type *type = Type::getSampledImage(theContext, imageType);
   const uint32_t typeId = theContext.getResultIdForType(type);
   const uint32_t typeId = theContext.getResultIdForType(type);
@@ -673,6 +674,33 @@ uint32_t ModuleBuilder::getSampledImageType(uint32_t imageType) {
   return typeId;
   return typeId;
 }
 }
 
 
+uint32_t ModuleBuilder::getByteAddressBufferType(bool isRW) {
+  // Create a uint RuntimeArray with Array Stride of 4.
+  const uint32_t uintType = getUint32Type();
+  const auto *arrStride4 = Decoration::getArrayStride(theContext, 4u);
+  const Type *raType =
+      Type::getRuntimeArray(theContext, uintType, {arrStride4});
+  const uint32_t raTypeId = theContext.getResultIdForType(raType);
+  theModule.addType(raType, raTypeId);
+
+  // Create a struct containing the runtime array as its only member.
+  // The struct must also be decorated as BufferBlock. The offset decoration
+  // should also be applied to the first (only) member. NonWritable decoration
+  // should also be applied to the first member if isRW is true.
+  llvm::SmallVector<const Decoration*, 3> typeDecs;
+  typeDecs.push_back(Decoration::getBufferBlock(theContext));
+  typeDecs.push_back(Decoration::getOffset(theContext, 0, 0));
+  if (!isRW)
+    typeDecs.push_back(Decoration::getNonWritable(theContext, 0));
+
+  const Type *type = Type::getStruct(theContext, {raTypeId}, typeDecs);
+  const uint32_t typeId = theContext.getResultIdForType(type);
+  theModule.addType(type, typeId);
+  theModule.addDebugName(typeId, isRW ? "type.RWByteAddressBuffer"
+                                      : "type.ByteAddressBuffer");
+  return typeId;
+}
+
 uint32_t ModuleBuilder::getConstantBool(bool value) {
 uint32_t ModuleBuilder::getConstantBool(bool value) {
   const uint32_t typeId = getBoolType();
   const uint32_t typeId = getBoolType();
   const Constant *constant = value ? Constant::getTrue(theContext, typeId)
   const Constant *constant = value ? Constant::getTrue(theContext, typeId)

+ 117 - 2
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1284,6 +1284,94 @@ uint32_t SPIRVEmitter::doConditionalOperator(const ConditionalOperator *expr) {
   return theBuilder.createSelect(type, condition, trueBranch, falseBranch);
   return theBuilder.createSelect(type, condition, trueBranch, falseBranch);
 }
 }
 
 
+uint32_t SPIRVEmitter::processByteAddressBufferLoadStore(
+    const CXXMemberCallExpr *expr, uint32_t numWords, bool doStore) {
+  uint32_t resultId = 0;
+  const auto object = expr->getImplicitObjectArgument();
+  const auto type = object->getType();
+  const uint32_t objectId = doExpr(object);
+  assert(numWords >= 1 && numWords <= 4);
+  if (doStore) {
+    assert(typeTranslator.isRWByteAddressBuffer(type));
+    assert(expr->getNumArgs() == 2);
+  } else {
+    assert(typeTranslator.isRWByteAddressBuffer(type) ||
+           typeTranslator.isByteAddressBuffer(type));
+    if (expr->getNumArgs() == 2) {
+
+      emitError("Load(in Address, out Status) has not been implemented for "
+                "(RW)ByteAddressBuffer yet.");
+      return 0;
+    }
+  }
+  const Expr *addressExpr = expr->getArg(0);
+  const uint32_t byteAddress = doExpr(addressExpr);
+  const uint32_t addressTypeId =
+      typeTranslator.translateType(addressExpr->getType());
+
+  // Do a OpShiftRightLogical by 2 (divide by 4 to get aligned memory
+  // access). The AST always casts the address to unsinged integer, so shift
+  // by unsinged integer 2.
+  const uint32_t constUint2 = theBuilder.getConstantUint32(2);
+  const uint32_t address = theBuilder.createBinaryOp(
+      spv::Op::OpShiftRightLogical, addressTypeId, byteAddress, constUint2);
+
+  // Perform access chain into the RWByteAddressBuffer.
+  // First index must be zero (member 0 of the struct is a
+  // runtimeArray). The second index passed to OpAccessChain should be
+  // the address.
+  const uint32_t uintTypeId = theBuilder.getUint32Type();
+  const uint32_t ptrType = theBuilder.getPointerType(
+      uintTypeId, declIdMapper.resolveStorageClass(object));
+  const uint32_t constUint0 = theBuilder.getConstantUint32(0);
+
+  if (doStore) {
+    const uint32_t valuesId = doExpr(expr->getArg(1));
+    uint32_t curStoreAddress = address;
+    for (uint32_t wordCounter = 0; wordCounter < numWords; ++wordCounter) {
+      // Extract a 32-bit word from the input.
+      const uint32_t curValue = numWords == 1
+                                    ? valuesId
+                                    : theBuilder.createCompositeExtract(
+                                          uintTypeId, valuesId, {wordCounter});
+
+      // Update the output address if necessary.
+      if (wordCounter > 0) {
+        const uint32_t offset = theBuilder.getConstantUint32(wordCounter);
+        curStoreAddress = theBuilder.createBinaryOp(
+            spv::Op::OpIAdd, addressTypeId, address, offset);
+      }
+
+      // Store the word to the right address at the output.
+      const uint32_t storePtr = theBuilder.createAccessChain(
+          ptrType, objectId, {constUint0, curStoreAddress});
+      theBuilder.createStore(storePtr, curValue);
+    }
+  } else {
+    uint32_t loadPtr =
+        theBuilder.createAccessChain(ptrType, objectId, {constUint0, address});
+    resultId = theBuilder.createLoad(uintTypeId, loadPtr);
+    if (numWords > 1) {
+      // Load word 2, 3, and 4 where necessary. Use OpCompositeConstruct to
+      // return a vector result.
+      llvm::SmallVector<uint32_t, 4> values;
+      values.push_back(resultId);
+      for (uint32_t wordCounter = 2; wordCounter <= numWords; ++wordCounter) {
+        const uint32_t offset = theBuilder.getConstantUint32(wordCounter - 1);
+        const uint32_t newAddress = theBuilder.createBinaryOp(
+            spv::Op::OpIAdd, addressTypeId, address, offset);
+        loadPtr = theBuilder.createAccessChain(ptrType, objectId,
+                                               {constUint0, newAddress});
+        values.push_back(theBuilder.createLoad(uintTypeId, loadPtr));
+      }
+      const uint32_t resultType =
+          theBuilder.getVecType(addressTypeId, numWords);
+      resultId = theBuilder.createCompositeConstruct(resultType, values);
+    }
+  }
+  return resultId;
+}
+
 uint32_t SPIRVEmitter::doCXXMemberCallExpr(const CXXMemberCallExpr *expr) {
 uint32_t SPIRVEmitter::doCXXMemberCallExpr(const CXXMemberCallExpr *expr) {
   using namespace hlsl;
   using namespace hlsl;
 
 
@@ -1419,8 +1507,14 @@ uint32_t SPIRVEmitter::doCXXMemberCallExpr(const CXXMemberCallExpr *expr) {
         return 0;
         return 0;
       }
       }
 
 
-      const auto *imageExpr = expr->getImplicitObjectArgument();
-      const uint32_t image = loadIfGLValue(imageExpr);
+      const auto *object = expr->getImplicitObjectArgument();
+      const auto objectType = object->getType();
+      if (typeTranslator.isRWByteAddressBuffer(objectType) ||
+          typeTranslator.isByteAddressBuffer(objectType)) {
+        return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ false);
+      }
+
+      const uint32_t image = loadIfGLValue(object);
 
 
       // The location parameter is a vector that consists of both the coordinate
       // The location parameter is a vector that consists of both the coordinate
       // and the mipmap level (via the last vector element). We need to split it
       // and the mipmap level (via the last vector element). We need to split it
@@ -1438,6 +1532,27 @@ uint32_t SPIRVEmitter::doCXXMemberCallExpr(const CXXMemberCallExpr *expr) {
       return theBuilder.createImageFetch(retType, image, coordinate, lod,
       return theBuilder.createImageFetch(retType, image, coordinate, lod,
                                          constOffset, varOffset);
                                          constOffset, varOffset);
     }
     }
+    case IntrinsicOp::MOP_Load2: {
+      return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ false);
+    }
+    case IntrinsicOp::MOP_Load3: {
+      return processByteAddressBufferLoadStore(expr, 3, /*doStore*/ false);
+    }
+    case IntrinsicOp::MOP_Load4: {
+      return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ false);
+    }
+    case IntrinsicOp::MOP_Store: {
+      return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ true);
+    }
+    case IntrinsicOp::MOP_Store2: {
+      return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ true);
+    }
+    case IntrinsicOp::MOP_Store3: {
+      return processByteAddressBufferLoadStore(expr, 3, /*doStore*/ true);
+    }
+    case IntrinsicOp::MOP_Store4: {
+      return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ true);
+    }
     default:
     default:
       emitError("HLSL intrinsic member call unimplemented: %0")
       emitError("HLSL intrinsic member call unimplemented: %0")
           << callee->getName();
           << callee->getName();

+ 9 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -415,6 +415,15 @@ private:
   /// statement.
   /// statement.
   void processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt);
   void processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt);
 
 
+private:
+  /// \brief Loads numWords 32-bit unsigned integers or stores numWords 32-bit
+  /// unsigned integers (based on the doStore parameter) to the given
+  /// ByteAddressBuffer. Loading is allowed from a ByteAddressBuffer or
+  /// RWByteAddressBuffer. Storing is allowed only to RWByteAddressBuffer.
+  /// Panics if it is not the case.
+  uint32_t processByteAddressBufferLoadStore(const CXXMemberCallExpr *,
+                                             uint32_t numWords, bool doStore);
+
 private:
 private:
   /// \brief Wrapper method to create an error message and report it
   /// \brief Wrapper method to create an error message and report it
   /// in the diagnostic engine associated with this consumer.
   /// in the diagnostic engine associated with this consumer.

+ 2 - 2
tools/clang/lib/SPIRV/Structure.cpp

@@ -275,8 +275,8 @@ void SPIRVModule::take(InstBuilder *builder) {
     }
     }
   }
   }
 
 
-  for (const auto &d : decorations) {
-    consumer(d.decoration.withTargetId(d.targetId));
+  for (const auto &idDecorPair : decorations) {
+    consumer(idDecorPair.second->withTargetId(idDecorPair.first));
   }
   }
 
 
   // Note on interdependence of types and constants:
   // Note on interdependence of types and constants:

+ 5 - 4
tools/clang/lib/SPIRV/Type.cpp

@@ -14,9 +14,10 @@
 namespace clang {
 namespace clang {
 namespace spirv {
 namespace spirv {
 
 
-Type::Type(spv::Op op, std::vector<uint32_t> arg,
-           std::set<const Decoration *> decs)
-    : opcode(op), args(std::move(arg)), decorations(std::move(decs)) {}
+Type::Type(spv::Op op, std::vector<uint32_t> arg, DecorationSet decs)
+    : opcode(op), args(std::move(arg)) {
+  decorations = llvm::SetVector<const Decoration *>(decs.begin(), decs.end());
+}
 
 
 const Type *Type::getUniqueType(SPIRVContext &context, const Type &t) {
 const Type *Type::getUniqueType(SPIRVContext &context, const Type &t) {
   return context.registerType(t);
   return context.registerType(t);
@@ -197,7 +198,7 @@ bool Type::isCompositeType() const {
 bool Type::isImageType() const { return opcode == spv::Op::OpTypeImage; }
 bool Type::isImageType() const { return opcode == spv::Op::OpTypeImage; }
 
 
 bool Type::hasDecoration(const Decoration *d) const {
 bool Type::hasDecoration(const Decoration *d) const {
-  return decorations.find(d) != decorations.end();
+  return decorations.count(d);
 }
 }
 
 
 std::vector<uint32_t> Type::withResultId(uint32_t resultId) const {
 std::vector<uint32_t> Type::withResultId(uint32_t resultId) const {

+ 23 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -166,6 +166,20 @@ bool TypeTranslator::isScalarType(QualType type, QualType *scalarType) {
   return isScalar;
   return isScalar;
 }
 }
 
 
+bool TypeTranslator::isRWByteAddressBuffer(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>()) {
+    return rt->getDecl()->getName() == "RWByteAddressBuffer";
+  }
+  return false;
+}
+
+bool TypeTranslator::isByteAddressBuffer(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>()) {
+    return rt->getDecl()->getName() == "ByteAddressBuffer";
+  }
+  return false;
+}
+
 bool TypeTranslator::isVectorType(QualType type, QualType *elemType,
 bool TypeTranslator::isVectorType(QualType type, QualType *elemType,
                                   uint32_t *elemCount) {
                                   uint32_t *elemCount) {
   bool isVec = false;
   bool isVec = false;
@@ -334,6 +348,15 @@ uint32_t TypeTranslator::translateResourceType(QualType type) {
     return theBuilder.getSamplerType();
     return theBuilder.getSamplerType();
   }
   }
 
 
+  // ByteAddressBuffer types.
+  if (name == "ByteAddressBuffer") {
+    return theBuilder.getByteAddressBufferType(/*isRW*/ false);
+  }
+  // RWByteAddressBuffer types.
+  if (name == "RWByteAddressBuffer") {
+    return theBuilder.getByteAddressBufferType(/*isRW*/ true);
+  }
+
   return 0;
   return 0;
 }
 }
 
 

+ 6 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -38,6 +38,12 @@ public:
   /// on will be generated.
   /// on will be generated.
   uint32_t translateType(QualType type);
   uint32_t translateType(QualType type);
 
 
+  /// \brief Returns true if the given type is the HLSL ByteAddressBufferType.
+  bool isByteAddressBuffer(QualType type);
+
+  /// \brief Returns true if the given type is the HLSL RWByteAddressBufferType.
+  bool isRWByteAddressBuffer(QualType type);
+
   /// \brief Returns true if the given type will be translated into a SPIR-V
   /// \brief Returns true if the given type will be translated into a SPIR-V
   /// scalar type. This includes normal scalar types, vectors of size 1, and
   /// scalar type. This includes normal scalar types, vectors of size 1, and
   /// 1x1 matrices. If scalarType is not nullptr, writes the scalar type to
   /// 1x1 matrices. If scalarType is not nullptr, writes the scalar type to

+ 53 - 0
tools/clang/test/CodeGenSPIRV/method.byte-address-buffer.load.hlsl

@@ -0,0 +1,53 @@
+// Run: %dxc -T cs_6_0 -E main
+
+ByteAddressBuffer myBuffer;
+
+[numthreads(1, 1, 1)]
+void main() {
+  uint addr = 0;
+
+// CHECK: [[addr1:%\d+]] = OpLoad %uint %addr
+// CHECK-NEXT: [[word_addr:%\d+]] = OpShiftRightLogical %uint [[addr1]] %uint_2
+// CHECK-NEXT: [[load_ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[word_addr]]
+// CHECK-NEXT: {{%\d+}} = OpLoad %uint [[load_ptr]]
+  uint word = myBuffer.Load(addr);
+
+// CHECK: [[addr3:%\d+]] = OpLoad %uint %addr
+// CHECK-NEXT: [[load2_word0Addr:%\d+]] = OpShiftRightLogical %uint [[addr3]] %uint_2
+// CHECK-NEXT: [[load_ptr10:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[load2_word0Addr]]
+// CHECK-NEXT: [[load2_word0:%\d+]] = OpLoad %uint [[load_ptr10]]
+// CHECK-NEXT: [[load2_word1Addr:%\d+]] = OpIAdd %uint [[load2_word0Addr]] %uint_1
+// CHECK-NEXT: [[load_ptr11:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[load2_word1Addr]]
+// CHECK-NEXT: [[load2_word1:%\d+]] = OpLoad %uint [[load_ptr11]]
+// CHECK-NEXT: {{%\d+}} = OpCompositeConstruct %v2uint [[load2_word0]] [[load2_word1]]
+  uint2 word2 = myBuffer.Load2(addr);
+
+// CHECK: [[addr2:%\d+]] = OpLoad %uint %addr
+// CHECK-NEXT: [[load3_word0Addr:%\d+]] = OpShiftRightLogical %uint [[addr2]] %uint_2
+// CHECK-NEXT: [[load_ptr7:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[load3_word0Addr]]
+// CHECK-NEXT: [[load3_word0:%\d+]] = OpLoad %uint [[load_ptr7]]
+// CHECK-NEXT: [[load3_word1Addr:%\d+]] = OpIAdd %uint [[load3_word0Addr]] %uint_1
+// CHECK-NEXT: [[load_ptr8:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[load3_word1Addr]]
+// CHECK-NEXT: [[load3_word1:%\d+]] = OpLoad %uint [[load_ptr8]]
+// CHECK-NEXT: [[load3_word2Addr:%\d+]] = OpIAdd %uint [[load3_word0Addr]] %uint_2
+// CHECK-NEXT: [[load_ptr9:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[load3_word2Addr]]
+// CHECK-NEXT: [[load3_word2:%\d+]] = OpLoad %uint [[load_ptr9]]
+// CHECK-NEXT: {{%\d+}} = OpCompositeConstruct %v3uint [[load3_word0]] [[load3_word1]] [[load3_word2]]
+  uint3 word3 = myBuffer.Load3(addr);
+
+// CHECK: [[addr:%\d+]] = OpLoad %uint %addr
+// CHECK-NEXT: [[load4_word0Addr:%\d+]] = OpShiftRightLogical %uint [[addr]] %uint_2
+// CHECK-NEXT: [[load_ptr3:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[load4_word0Addr]]
+// CHECK-NEXT: [[load4_word0:%\d+]] = OpLoad %uint [[load_ptr3]]
+// CHECK-NEXT: [[load4_word1Addr:%\d+]] = OpIAdd %uint [[load4_word0Addr]] %uint_1
+// CHECK-NEXT: [[load_ptr4:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[load4_word1Addr]]
+// CHECK-NEXT: [[load4_word1:%\d+]] = OpLoad %uint [[load_ptr4]]
+// CHECK-NEXT: [[load4_word2Addr:%\d+]] = OpIAdd %uint [[load4_word0Addr]] %uint_2
+// CHECK-NEXT: [[load_ptr5:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[load4_word2Addr]]
+// CHECK-NEXT: [[load4_word2:%\d+]] = OpLoad %uint [[load_ptr5]]
+// CHECK-NEXT: [[load4_word3Addr:%\d+]] = OpIAdd %uint [[load4_word0Addr]] %uint_3
+// CHECK-NEXT: [[load_ptr6:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[load4_word3Addr]]
+// CHECK-NEXT: [[load4_word3:%\d+]] = OpLoad %uint [[load_ptr6]]
+// CHECK-NEXT: {{%\d+}} = OpCompositeConstruct %v4uint [[load4_word0]] [[load4_word1]] [[load4_word2]] [[load4_word3]]
+  uint4 word4 = myBuffer.Load4(addr);
+}

+ 70 - 0
tools/clang/test/CodeGenSPIRV/method.byte-address-buffer.store.hlsl

@@ -0,0 +1,70 @@
+// Run: %dxc -T cs_6_0 -E main
+
+RWByteAddressBuffer outBuffer;
+
+[numthreads(1, 1, 1)]
+void main() {
+  uint addr = 0;
+  uint words1 = 1;
+  uint2 words2 = uint2(1, 2);
+  uint3 words3 = uint3(1, 2, 3);
+  uint4 words4 = uint4(1, 2, 3, 4);
+
+// CHECK:      [[byteAddr1:%\d+]] = OpLoad %uint %addr
+// CHECK-NEXT: [[baseAddr1:%\d+]] = OpShiftRightLogical %uint [[byteAddr1]] %uint_2
+// CHECK-NEXT: [[words1:%\d+]] = OpLoad %uint %words1
+// CHECK-NEXT: [[out1_outBufPtr0:%\d+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[baseAddr1]]
+// CHECK-NEXT: OpStore [[out1_outBufPtr0]] [[words1]]
+  outBuffer.Store(addr, words1);
+
+
+// CHECK:      [[byteAddr2:%\d+]] = OpLoad %uint %addr
+// CHECK-NEXT: [[baseAddr2:%\d+]] = OpShiftRightLogical %uint [[byteAddr2]] %uint_2
+// CHECK-NEXT: [[words2:%\d+]] = OpLoad %v2uint %words2
+// CHECK-NEXT: [[words2_0:%\d+]] = OpCompositeExtract %uint [[words2]] 0
+// CHECK-NEXT: [[out2_outBufPtr0:%\d+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[baseAddr2]]
+// CHECK-NEXT: OpStore [[out2_outBufPtr0]] [[words2_0]]
+// CHECK-NEXT: [[words2_1:%\d+]] = OpCompositeExtract %uint [[words2]] 1
+// CHECK-NEXT: [[baseAddr2_plus1:%\d+]] = OpIAdd %uint [[baseAddr2]] %uint_1
+// CHECK-NEXT: [[out2_outBufPtr1:%\d+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[baseAddr2_plus1]]
+// CHECK-NEXT: OpStore [[out2_outBufPtr1]] [[words2_1]]
+  outBuffer.Store2(addr, words2);
+
+
+// CHECK:      [[byteAddr3:%\d+]] = OpLoad %uint %addr
+// CHECK-NEXT: [[baseAddr3:%\d+]] = OpShiftRightLogical %uint [[byteAddr3]] %uint_2
+// CHECK-NEXT: [[words3:%\d+]] = OpLoad %v3uint %words3
+// CHECK-NEXT: [[word3_0:%\d+]] = OpCompositeExtract %uint [[words3]] 0
+// CHECK-NEXT: [[out3_outBufPtr0:%\d+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[baseAddr3]]
+// CHECK-NEXT: OpStore [[out3_outBufPtr0]] [[word3_0]]
+// CHECK-NEXT: [[words3_1:%\d+]] = OpCompositeExtract %uint [[words3]] 1
+// CHECK-NEXT: [[baseAddr3_plus1:%\d+]] = OpIAdd %uint [[baseAddr3]] %uint_1
+// CHECK-NEXT: [[out3_outBufPtr1:%\d+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[baseAddr3_plus1]]
+// CHECK-NEXT: OpStore [[out3_outBufPtr1]] [[words3_1]]
+// CHECK-NEXT: [[word3_2:%\d+]] = OpCompositeExtract %uint [[words3]] 2
+// CHECK-NEXT: [[baseAddr3_plus2:%\d+]] = OpIAdd %uint [[baseAddr3]] %uint_2
+// CHECK-NEXT: [[out3_outBufPtr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[baseAddr3_plus2]]
+// CHECK-NEXT: OpStore [[out3_outBufPtr2]] [[word3_2]]
+  outBuffer.Store3(addr, words3);
+
+
+// CHECK:      [[byteAddr:%\d+]] = OpLoad %uint %addr
+// CHECK-NEXT: [[baseAddr:%\d+]] = OpShiftRightLogical %uint [[byteAddr]] %uint_2
+// CHECK-NEXT: [[words4:%\d+]] = OpLoad %v4uint %words4
+// CHECK-NEXT: [[word0:%\d+]] = OpCompositeExtract %uint [[words4]] 0
+// CHECK-NEXT: [[outBufPtr0:%\d+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[baseAddr]]
+// CHECK-NEXT: OpStore [[outBufPtr0]] [[word0]]
+// CHECK-NEXT: [[word1:%\d+]] = OpCompositeExtract %uint [[words4]] 1
+// CHECK-NEXT: [[baseAddr_plus1:%\d+]] = OpIAdd %uint [[baseAddr]] %uint_1
+// CHECK-NEXT: [[outBufPtr1:%\d+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[baseAddr_plus1]]
+// CHECK-NEXT: OpStore [[outBufPtr1]] [[word1]]
+// CHECK-NEXT: [[word2:%\d+]] = OpCompositeExtract %uint [[words4]] 2
+// CHECK-NEXT: [[baseAddr_plus2:%\d+]] = OpIAdd %uint [[baseAddr]] %uint_2
+// CHECK-NEXT: [[outBufPtr2:%\d+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[baseAddr_plus2]]
+// CHECK-NEXT: OpStore [[outBufPtr2]] [[word2]]
+// CHECK-NEXT: [[word3:%\d+]] = OpCompositeExtract %uint [[words4]] 3
+// CHECK-NEXT: [[baseAddr_plus3:%\d+]] = OpIAdd %uint [[baseAddr]] %uint_3
+// CHECK-NEXT: [[outBufPtr3:%\d+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[baseAddr_plus3]]
+// CHECK-NEXT: OpStore [[outBufPtr3]] [[word3]]
+  outBuffer.Store4(addr, words4);
+}

+ 23 - 0
tools/clang/test/CodeGenSPIRV/type.byte-address-buffer.hlsl

@@ -0,0 +1,23 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: OpName %type_ByteAddressBuffer "type.ByteAddressBuffer"
+// CHECK: OpName %type_RWByteAddressBuffer "type.RWByteAddressBuffer"
+// CHECK: OpDecorate %_runtimearr_uint ArrayStride 4
+// CHECK: OpDecorate %type_ByteAddressBuffer BufferBlock
+// CHECK: OpMemberDecorate %type_ByteAddressBuffer 0 Offset 0
+// CHECK: OpMemberDecorate %type_ByteAddressBuffer 0 NonWritable
+// CHECK: OpDecorate %type_RWByteAddressBuffer BufferBlock
+// CHECK: OpMemberDecorate %type_RWByteAddressBuffer 0 Offset 0
+// CHECK: %_runtimearr_uint = OpTypeRuntimeArray %uint
+// CHECK: %type_ByteAddressBuffer = OpTypeStruct %_runtimearr_uint
+// CHECK: %_ptr_Uniform_type_ByteAddressBuffer = OpTypePointer Uniform %type_ByteAddressBuffer
+// CHECK: %type_RWByteAddressBuffer = OpTypeStruct %_runtimearr_uint
+// CHECK: %_ptr_Uniform_type_RWByteAddressBuffer = OpTypePointer Uniform %type_RWByteAddressBuffer
+// CHECK: %Buffer0 = OpVariable %_ptr_Uniform_type_ByteAddressBuffer Uniform
+// CHECK: %BufferOut = OpVariable %_ptr_Uniform_type_RWByteAddressBuffer Uniform
+
+ByteAddressBuffer Buffer0;
+RWByteAddressBuffer BufferOut;
+
+[numthreads(1, 1, 1)]
+void main() {}

+ 11 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -43,6 +43,9 @@ TEST_F(FileTest, ArrayTypes) { runFileTest("type.array.hlsl"); }
 TEST_F(FileTest, TypedefTypes) { runFileTest("type.typedef.hlsl"); }
 TEST_F(FileTest, TypedefTypes) { runFileTest("type.typedef.hlsl"); }
 TEST_F(FileTest, SamplerTypes) { runFileTest("type.sampler.hlsl"); }
 TEST_F(FileTest, SamplerTypes) { runFileTest("type.sampler.hlsl"); }
 TEST_F(FileTest, TextureTypes) { runFileTest("type.texture.hlsl"); }
 TEST_F(FileTest, TextureTypes) { runFileTest("type.texture.hlsl"); }
+TEST_F(FileTest, ByteAddressBufferTypes) {
+  runFileTest("type.byte-address-buffer.hlsl");
+}
 
 
 // For constants
 // For constants
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }
@@ -322,6 +325,14 @@ TEST_F(FileTest, TextureArraySampleGrad) {
   runFileTest("texture.array.sample-grad.hlsl");
   runFileTest("texture.array.sample-grad.hlsl");
 }
 }
 
 
+// For ByteAddressBuffer methods
+TEST_F(FileTest, ByteAddressBufferLoad) {
+  runFileTest("method.byte-address-buffer.load.hlsl");
+}
+TEST_F(FileTest, ByteAddressBufferStore) {
+  runFileTest("method.byte-address-buffer.store.hlsl");
+}
+
 // For intrinsic functions
 // For intrinsic functions
 TEST_F(FileTest, IntrinsicsDot) { runFileTest("intrinsics.dot.hlsl"); }
 TEST_F(FileTest, IntrinsicsDot) { runFileTest("intrinsics.dot.hlsl"); }
 TEST_F(FileTest, IntrinsicsMul) { runFileTest("intrinsics.mul.hlsl"); }
 TEST_F(FileTest, IntrinsicsMul) { runFileTest("intrinsics.mul.hlsl"); }

+ 2 - 2
tools/clang/unittests/SPIRV/StructureTest.cpp

@@ -166,7 +166,7 @@ TEST(Structure, TakeModuleHaveAllContents) {
   sib.inst(spv::Op::OpName,
   sib.inst(spv::Op::OpName,
            {entryPointId, mainWord, /* additional null in name */ 0});
            {entryPointId, mainWord, /* additional null in name */ 0});
 
 
-  m.addDecoration(*Decoration::getRelaxedPrecision(context), entryPointId);
+  m.addDecoration(Decoration::getRelaxedPrecision(context), entryPointId);
   sib.inst(
   sib.inst(
       spv::Op::OpDecorate,
       spv::Op::OpDecorate,
       {entryPointId, static_cast<uint32_t>(spv::Decoration::RelaxedPrecision)});
       {entryPointId, static_cast<uint32_t>(spv::Decoration::RelaxedPrecision)});
@@ -269,7 +269,7 @@ TEST(Structure, TakeModuleWithArrayAndConstantDependency) {
       Type::getArray(context, i32Id, constantId, {arrStride});
       Type::getArray(context, i32Id, constantId, {arrStride});
   const uint32_t secondArrId = context.getResultIdForType(arrType);
   const uint32_t secondArrId = context.getResultIdForType(arrType);
   m.addType(secondArrType, secondArrId);
   m.addType(secondArrType, secondArrId);
-  m.addDecoration(*arrStride, secondArrId);
+  m.addDecoration(arrStride, secondArrId);
   m.setBound(context.getNextId());
   m.setBound(context.getNextId());
 
 
   // Decorations
   // Decorations