Explorar el Código

[spirv] Start effort to move to new infrastructure.

Ehsan Nasiri hace 6 años
padre
commit
d9bc292437

+ 0 - 1
lib/Support/YAMLTraits.cpp

@@ -10,7 +10,6 @@
 #include "llvm/Support/YAMLTraits.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/Twine.h"
-#include "llvm/Support/Casting.h"
 #include "llvm/Support/Errc.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/Format.h"

+ 29 - 0
tools/clang/include/clang/SPIRV/AstTypeProbe.h

@@ -63,6 +63,35 @@ bool isMx1Matrix(QualType type, QualType *elemType = nullptr,
 bool isMxNMatrix(QualType type, QualType *elemType = nullptr,
                  uint32_t *rowCount = nullptr, uint32_t *colCount = nullptr);
 
+/// Returns true if the given type is or contains any kind of structured-buffer
+/// or byte-address-buffer.
+bool isOrContainsAKindOfStructuredOrByteBuffer(QualType type);
+
+/// \brief Returns true if the given type is SubpassInput.
+bool isSubpassInput(QualType);
+
+/// \brief Returns true if the given type is SubpassInputMS.
+bool isSubpassInputMS(QualType);
+
+/// \brief Returns true if the decl is of ConstantBuffer/TextureBuffer type.
+bool isConstantTextureBuffer(const Decl *decl);
+
+/// \brief Returns true if the decl will have a SPIR-V resource type.
+///
+/// Note that this function covers the following HLSL types:
+/// * ConstantBuffer/TextureBuffer
+/// * Various structured buffers
+/// * (RW)ByteAddressBuffer
+/// * SubpassInput(MS)
+bool isResourceType(const ValueDecl *decl);
+
+/// \brief Returns true if the given type is or contains 16-bit type.
+//bool isOrContains16BitType(QualType type);
+
+  /// \brief Returns true if the given type is the HLSL (RW)StructuredBuffer,
+  /// (RW)ByteAddressBuffer, or {Append|Consume}StructuredBuffer.
+bool isAKindOfStructuredOrByteBuffer(QualType type);
+
 } // namespace spirv
 } // namespace clang
 

+ 9 - 0
tools/clang/include/clang/SPIRV/SPIRVContext.h

@@ -194,6 +194,12 @@ public:
                 StructType::InterfaceType interfaceType =
                     StructType::InterfaceType::InternalStorage);
 
+  const HybridStructType *
+  getHybridStructType(llvm::ArrayRef<HybridStructType::FieldInfo> fields,
+                      llvm::StringRef name, bool isReadOnly = false,
+                      HybridStructType::InterfaceType interfaceType =
+                          HybridStructType::InterfaceType::InternalStorage);
+
   const SpirvPointerType *getPointerType(const SpirvType *pointee,
                                          spv::StorageClass);
 
@@ -201,8 +207,10 @@ public:
                                       llvm::ArrayRef<const SpirvType *> param);
 
   const StructType *getByteAddressBufferType(bool isWritable);
+  const StructType *getACSBufferCounterType();
 
   SpirvConstant *getConstantUint32(uint32_t value);
+  SpirvConstant *getConstantInt32(int32_t value);
   // TODO: Add getConstant* methods for other types.
 
 private:
@@ -248,6 +256,7 @@ private:
   llvm::DenseMap<const SpirvType *, const RuntimeArrayType *> runtimeArrayTypes;
 
   llvm::SmallVector<const StructType *, 8> structTypes;
+  llvm::SmallVector<const HybridStructType *, 8> hybridStructTypes;
 
   llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
 

+ 31 - 7
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -63,6 +63,14 @@ public:
   SpirvFunctionParameter *addFnParam(QualType ptrType, SourceLocation,
                                      llvm::StringRef name = "");
 
+  /// \brief Creates a SpirvFunction object and adds it to the list of module
+  /// functions. This does not change the current function under construction.
+  /// The handle can be used to create function call instructions for functions
+  /// that we have not yet discovered in the source code.
+  SpirvFunction *createFunction(QualType returnType, SourceLocation,
+                                llvm::StringRef name = "",
+                                bool isAlias = false);
+
   /// \brief Creates a local variable of the given type in the current
   /// function and returns it.
   ///
@@ -113,6 +121,10 @@ public:
   createCompositeConstruct(QualType resultType,
                            llvm::ArrayRef<SpirvInstruction *> constituents,
                            SourceLocation loc = {});
+  SpirvComposite *
+  createCompositeConstruct(const SpirvType *resultType,
+                           llvm::ArrayRef<SpirvInstruction *> constituents,
+                           SourceLocation loc = {});
 
   /// \brief Creates a composite extract instruction. The given composite is
   /// indexed using the given literal indexes to obtain the resulting element.
@@ -145,6 +157,8 @@ public:
   /// the loaded value.
   SpirvLoad *createLoad(QualType resultType, SpirvInstruction *pointer,
                         SourceLocation loc = {});
+  SpirvLoad *createLoad(const SpirvType *resultType, SpirvInstruction *pointer,
+                        SourceLocation loc = {});
 
   /// \brief Creates a store instruction storing the given value into the given
   /// address.
@@ -165,6 +179,10 @@ public:
   createAccessChain(QualType resultType, SpirvInstruction *base,
                     llvm::ArrayRef<SpirvInstruction *> indexes,
                     SourceLocation loc = {});
+  SpirvAccessChain *
+  createAccessChain(const SpirvType *resultType, SpirvInstruction *base,
+                    llvm::ArrayRef<SpirvInstruction *> indexes,
+                    SourceLocation loc = {});
 
   /// \brief Creates a unary operation with the given SPIR-V opcode. Returns
   /// the instruction pointer for the result.
@@ -403,8 +421,7 @@ public:
   inline void setSourceFileContent(llvm::StringRef content);
 
   /// \brief Adds an execution mode to the module under construction.
-  inline void addExecutionMode(SpirvEntryPoint *entryPoint,
-                               spv::ExecutionMode em,
+  inline void addExecutionMode(SpirvFunction *entryPoint, spv::ExecutionMode em,
                                llvm::ArrayRef<uint32_t> params,
                                SourceLocation loc = {});
 
@@ -428,7 +445,7 @@ public:
   ///
   /// Note: The corresponding pointer type of the given type will not be
   /// constructed in this method.
-  SpirvVariable *addStageBuiltinVar(QualType type,
+  SpirvVariable *addStageBuiltinVar(const SpirvType *type,
                                     spv::StorageClass storageClass,
                                     spv::BuiltIn, SourceLocation loc = {});
 
@@ -442,6 +459,13 @@ public:
                llvm::StringRef name = "",
                llvm::Optional<SpirvInstruction *> init = llvm::None,
                SourceLocation loc = {});
+  // TODO(ehsan): This API should be removed once aliasing has been moved to a
+  // pass.
+  SpirvVariable *
+  addModuleVar(const SpirvType *valueType, spv::StorageClass storageClass,
+               llvm::StringRef name = "",
+               llvm::Optional<SpirvInstruction *> init = llvm::None,
+               SourceLocation loc = {});
 
   /// \brief Decorates the given target with the given location.
   void decorateLocation(SpirvInstruction *target, uint32_t location,
@@ -467,9 +491,9 @@ public:
                                     SourceLocation srcLoc = {});
 
   /// \brief Decorates the given main buffer with the given counter buffer.
-  void decorateCounterBufferId(SpirvInstruction *mainBuffer,
-                               uint32_t counterBufferId,
-                               SourceLocation srcLoc = {});
+  void decorateCounterBuffer(SpirvInstruction *mainBuffer,
+                             SpirvInstruction *counterBuffer,
+                             SourceLocation srcLoc = {});
 
   /// \brief Decorates the given target with the given HLSL semantic string.
   void decorateHlslSemantic(SpirvInstruction *target, llvm::StringRef semantic,
@@ -564,7 +588,7 @@ void SpirvBuilder::setSourceFileContent(llvm::StringRef content) {
   module->setSourceFileContent(content);
 }
 
-void SpirvBuilder::addExecutionMode(SpirvEntryPoint *entryPoint,
+void SpirvBuilder::addExecutionMode(SpirvFunction *entryPoint,
                                     spv::ExecutionMode em,
                                     llvm::ArrayRef<uint32_t> params,
                                     SourceLocation loc) {

+ 9 - 0
tools/clang/include/clang/SPIRV/SpirvFunction.h

@@ -66,6 +66,12 @@ public:
   // Returns the result-id of the OpTypeFunction
   uint32_t getFunctionTypeId() const { return fnTypeId; }
 
+  void setConstainsAliasComponent(bool isAlias) { containsAlias = isAlias; }
+  bool constainsAliasComponent() { return containsAlias; }
+
+  void setRValue() { rvalue = true; }
+  bool isRValue() { return rvalue; }
+
   void setFunctionName(llvm::StringRef name) { functionName = name; }
   llvm::StringRef getFunctionName() const { return functionName; }
 
@@ -83,6 +89,9 @@ private:
   FunctionType *fnType; ///< The SPIR-V function type
   uint32_t fnTypeId;    ///< result-id for the SPIR-V function type
 
+  bool containsAlias; ///< Whether function return type is aliased
+  bool rvalue;        ///< Whether the return value is an rvalue
+
   spv::FunctionControlMask functionControl; ///< SPIR-V function control
   SourceLocation functionLoc;               ///< Location in source code
   std::string functionName;                 ///< This function's name

+ 62 - 7
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -127,7 +127,7 @@ public:
 
   bool hasResultType() const { return resultType != nullptr; }
   const SpirvType *getResultType() const { return resultType; }
-  void setResultType(const SpirvType *t) { resultType = t; }
+  void setResultType(const SpirvType *type) { resultType = type; }
 
   // TODO: The responsibility of assigning the result-id of an instruction
   // shouldn't be on the instruction itself.
@@ -142,6 +142,27 @@ public:
   SpirvLayoutRule getLayoutRule() const { return layoutRule; }
   void setLayoutRule(SpirvLayoutRule rule) { layoutRule = rule; }
 
+  void setContainsAliasComponent(bool contains) { containsAlias = contains; }
+  bool containsAliasComponent() const { return containsAlias; }
+
+  void setStorageClass(spv::StorageClass sc) { storageClass = sc; }
+  spv::StorageClass getStorageClass() const { return storageClass; }
+
+  void setRValue(bool rvalue = true) { isRValue_ = rvalue; }
+  bool isRValue() const { return isRValue_; }
+
+  void setConstant() { isConstant_ = true; }
+  bool isConstant() const { return isConstant_; }
+
+  void setSpecConstant() { isSpecConstant_ = true; }
+  bool isSpecConstant() const { return isSpecConstant_; }
+
+  void setRelaxedPrecision() { isRelaxedPrecision_ = true; }
+  bool isRelaxedPrecision() const { return isRelaxedPrecision_; }
+
+  void setNonUniform(bool nu = true) { isNonUniform_ = true; }
+  bool isNonUniform() const { return isNonUniform_; }
+
 protected:
   // Forbid creating SpirvInstruction directly
   SpirvInstruction(Kind kind, spv::Op opcode, QualType astResultType,
@@ -158,6 +179,23 @@ protected:
   const SpirvType *resultType;
   uint32_t resultTypeId;
   SpirvLayoutRule layoutRule;
+
+  /// Indicates whether this evaluation result contains alias variables
+  ///
+  /// This field should only be true for stand-alone alias variables, which is
+  /// of pointer-to-pointer type, or struct variables containing alias fields.
+  /// After dereferencing the alias variable, this should be set to false to let
+  /// CodeGen fall back to normal handling path.
+  ///
+  /// Note: legalization specific code
+  bool containsAlias;
+
+  spv::StorageClass storageClass;
+  bool isRValue_;
+  bool isConstant_;
+  bool isSpecConstant_;
+  bool isRelaxedPrecision_;
+  bool isNonUniform_;
 };
 
 #define DECLARE_INVOKE_VISITOR_FOR_CLASS(cls)                                  \
@@ -267,7 +305,7 @@ private:
 /// \brief OpExecutionMode and OpExecutionModeId instructions
 class SpirvExecutionMode : public SpirvInstruction {
 public:
-  SpirvExecutionMode(SourceLocation loc, SpirvEntryPoint *entryPoint,
+  SpirvExecutionMode(SourceLocation loc, SpirvFunction *entryPointFunction,
                      spv::ExecutionMode, llvm::ArrayRef<uint32_t> params,
                      bool usesIdParams);
 
@@ -278,12 +316,12 @@ public:
 
   DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionMode)
 
-  SpirvEntryPoint *getEntryPoint() const { return entryPoint; }
+  SpirvFunction *getEntryPoint() const { return entryPoint; }
   spv::ExecutionMode getExecutionMode() const { return execMode; }
   llvm::ArrayRef<uint32_t> getParams() const { return params; }
 
 private:
-  SpirvEntryPoint *entryPoint;
+  SpirvFunction *entryPoint;
   spv::ExecutionMode execMode;
   llvm::SmallVector<uint32_t, 4> params;
 };
@@ -350,7 +388,7 @@ private:
   std::string process;
 };
 
-/// \brief OpDecorate and OpMemberDecorate instructions
+/// \brief OpDecorate(Id) and OpMemberDecorate instructions
 class SpirvDecoration : public SpirvInstruction {
 public:
   SpirvDecoration(SourceLocation loc, SpirvInstruction *target,
@@ -360,6 +398,11 @@ public:
                   spv::Decoration decor, llvm::StringRef stringParam,
                   llvm::Optional<uint32_t> index = llvm::None);
 
+  // Used for creating OpDecorateId instructions
+  SpirvDecoration(SourceLocation loc, SpirvInstruction *target,
+                  spv::Decoration decor,
+                  llvm::ArrayRef<SpirvInstruction *> params);
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Decoration;
@@ -372,6 +415,7 @@ public:
 
   spv::Decoration getDecoration() const { return decoration; }
   llvm::ArrayRef<uint32_t> getParams() const { return params; }
+  llvm::ArrayRef<SpirvInstruction *> getIdParams() const { return idParams; }
   bool isMemberDecoration() const { return index.hasValue(); }
   uint32_t getMemberIndex() const { return index.getValue(); }
 
@@ -380,11 +424,21 @@ private:
   spv::Decoration decoration;
   llvm::Optional<uint32_t> index;
   llvm::SmallVector<uint32_t, 4> params;
+  llvm::SmallVector<SpirvInstruction *, 4> idParams;
 };
 
 /// \brief OpVariable instruction
 class SpirvVariable : public SpirvInstruction {
 public:
+  /// \brief An enum class for representing what the DeclContext is used for
+  enum class ContextUsageKind {
+    CBuffer = 0,
+    TBuffer = 1,
+    PushConstant = 2,
+    Globals = 3,
+    None = 4
+  };
+
   SpirvVariable(QualType resultType, uint32_t resultId, SourceLocation loc,
                 spv::StorageClass sc, SpirvInstruction *initializerId = 0);
 
@@ -397,11 +451,12 @@ public:
 
   bool hasInitializer() const { return initializer != nullptr; }
   SpirvInstruction *getInitializer() const { return initializer; }
-  spv::StorageClass getStorageClass() const { return storageClass; }
+  void setContextUsageKind(ContextUsageKind k) { contextUsageKind = k; }
+  ContextUsageKind getContextUsageKind() const { return contextUsageKind; }
 
 private:
-  spv::StorageClass storageClass;
   SpirvInstruction *initializer;
+  ContextUsageKind contextUsageKind;
 };
 
 class SpirvFunctionParameter : public SpirvInstruction {

+ 2 - 1
tools/clang/include/clang/SPIRV/SpirvModule.h

@@ -13,6 +13,7 @@
 
 #include "clang/SPIRV/SpirvFunction.h"
 #include "clang/SPIRV/SpirvInstruction.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 
 namespace clang {
@@ -103,7 +104,7 @@ private:
   std::vector<SpirvVariable *> variables;
 
   // Shader logic instructions
-  std::vector<SpirvFunction *> functions;
+  llvm::SetVector<SpirvFunction *> functions;
   std::string sourceFileName;
   std::string sourceFileContent;
 };

+ 65 - 0
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -37,6 +37,7 @@ public:
     TK_Array,
     TK_RuntimeArray,
     TK_Struct,
+    TK_HybridStruct, // TODO: Remove once HybridStrcut type is removed.
     TK_Pointer,
     TK_Function,
   };
@@ -336,6 +337,70 @@ private:
   llvm::SmallVector<const SpirvType *, 8> paramTypes;
 };
 
+/// **NOTE**: This type is created in order to facilitate transition of old
+/// infrastructure to the new infrastructure. Using this type should be avoided
+/// as much as possible.
+///
+/// This type uses a mix of SpirvType and QualType for the structure fields.
+class HybridStructType : public SpirvType {
+public:
+  enum class InterfaceType : uint32_t {
+    InternalStorage = 0,
+    StorageBuffer = 1,
+    UniformBuffer = 2,
+  };
+
+  struct FieldInfo {
+  public:
+    FieldInfo(QualType astType_, const SpirvType *type_,
+              llvm::StringRef name_ = "", clang::VKOffsetAttr *offset = nullptr,
+              hlsl::ConstantPacking *packOffset = nullptr)
+        : astType(astType_), spirvType(type_), name(name_),
+          vkOffsetAttr(offset), packOffsetAttr(packOffset) {}
+
+    bool operator==(const FieldInfo &that) const;
+
+    // The field's type.
+    QualType astType;
+    const SpirvType *spirvType;
+    // The field's name.
+    std::string name;
+    // vk::offset attributes associated with this field.
+    clang::VKOffsetAttr *vkOffsetAttr;
+    // :packoffset() annotations associated with this field.
+    hlsl::ConstantPacking *packOffsetAttr;
+  };
+
+  HybridStructType(
+      llvm::ArrayRef<FieldInfo> fields, llvm::StringRef name, bool isReadOnly,
+      InterfaceType interfaceType = InterfaceType::InternalStorage);
+
+  static bool classof(const SpirvType *t) {
+    return t->getKind() == TK_HybridStruct;
+  }
+
+  llvm::ArrayRef<FieldInfo> getFields() const { return fields; }
+  bool isReadOnly() const { return readOnly; }
+  std::string getStructName() const { return structName; }
+  InterfaceType getInterfaceType() const { return interfaceType; }
+
+  bool operator==(const HybridStructType &that) const;
+
+private:
+  // Reflection is heavily used in graphics pipelines. Reflection relies on
+  // struct names and field names. That basically means we cannot ignore these
+  // names when considering unification. Otherwise, reflection will be confused.
+
+  llvm::SmallVector<FieldInfo, 8> fields;
+  std::string structName;
+  bool readOnly;
+  // Indicates the interface type of this structure. If this structure is a
+  // storage buffer shader-interface, it will be decorated with 'BufferBlock'.
+  // If this structure is a uniform buffer shader-interface, it will be
+  // decorated with 'Block'.
+  InterfaceType interfaceType;
+};
+
 } // end namespace spirv
 } // end namespace clang
 

+ 70 - 0
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -211,5 +211,75 @@ bool isMxNMatrix(QualType type, QualType *elemType, uint32_t *numRows,
   return false;
 }
 
+bool isOrContainsAKindOfStructuredOrByteBuffer(QualType type) {
+  if (const RecordType *recordType = type->getAs<RecordType>()) {
+    StringRef name = recordType->getDecl()->getName();
+    if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
+        name == "ByteAddressBuffer" || name == "RWByteAddressBuffer" ||
+        name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer")
+      return true;
+
+    for (const auto *field : recordType->getDecl()->fields()) {
+      if (isOrContainsAKindOfStructuredOrByteBuffer(field->getType()))
+        return true;
+    }
+  }
+  return false;
+}
+
+bool isSubpassInput(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>())
+    return rt->getDecl()->getName() == "SubpassInput";
+
+  return false;
+}
+
+bool isSubpassInputMS(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>())
+    return rt->getDecl()->getName() == "SubpassInputMS";
+
+  return false;
+}
+
+bool isConstantTextureBuffer(const Decl *decl) {
+  if (const auto *bufferDecl = dyn_cast<HLSLBufferDecl>(decl->getDeclContext()))
+    // Make sure we are not returning true for VarDecls inside cbuffer/tbuffer.
+    return bufferDecl->isConstantBufferView();
+
+  return false;
+}
+
+bool isResourceType(const ValueDecl *decl) {
+  if (isConstantTextureBuffer(decl))
+    return true;
+
+  QualType declType = decl->getType();
+
+  // Deprive the arrayness to see the element type
+  while (declType->isArrayType()) {
+    declType = declType->getAsArrayTypeUnsafe()->getElementType();
+  }
+
+  if (isSubpassInput(declType) || isSubpassInputMS(declType))
+    return true;
+
+  return hlsl::IsHLSLResourceType(declType);
+}
+
+bool isAKindOfStructuredOrByteBuffer(QualType type) {
+  // Strip outer arrayness first
+  while (type->isArrayType())
+    type = type->getAsArrayTypeUnsafe()->getElementType();
+
+  if (const RecordType *recordType = type->getAs<RecordType>()) {
+    StringRef name = recordType->getDecl()->getName();
+    return name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
+           name == "ByteAddressBuffer" || name == "RWByteAddressBuffer" ||
+           name == "AppendStructuredBuffer" ||
+           name == "ConsumeStructuredBuffer";
+  }
+  return false;
+}
+
 } // namespace spirv
 } // namespace clang

+ 279 - 226
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -22,6 +22,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Casting.h"
 
 #include "SPIRVEmitter.h"
 
@@ -29,6 +30,92 @@ namespace clang {
 namespace spirv {
 
 namespace {
+
+bool shouldSkipInStructLayout(const Decl *decl) {
+  // Ignore implicit generated struct declarations/constructors/destructors
+  if (decl->isImplicit())
+    return true;
+  // Ignore embedded type decls
+  if (isa<TypeDecl>(decl))
+    return true;
+  // Ignore embeded function decls
+  if (isa<FunctionDecl>(decl))
+    return true;
+  // Ignore empty decls
+  if (isa<EmptyDecl>(decl))
+    return true;
+
+  // For the $Globals cbuffer, we only care about externally-visiable
+  // non-resource-type variables. The rest should be filtered out.
+
+  const auto *declContext = decl->getDeclContext();
+
+  // Special check for ConstantBuffer/TextureBuffer, whose DeclContext is a
+  // HLSLBufferDecl. So that we need to check the HLSLBufferDecl's parent decl
+  // to check whether this is a ConstantBuffer/TextureBuffer defined in the
+  // global namespace.
+  // Note that we should not be seeing ConstantBuffer/TextureBuffer for normal
+  // cbuffer/tbuffer or push constant blocks. So this case should only happen
+  // for $Globals cbuffer.
+  if (isConstantTextureBuffer(decl) &&
+      declContext->getLexicalParent()->isTranslationUnit())
+    return true;
+
+  // $Globals' "struct" is the TranslationUnit, so we should ignore resources
+  // in the TranslationUnit "struct" and its child namespaces.
+  if (declContext->isTranslationUnit() || declContext->isNamespace()) {
+    // External visibility
+    if (const auto *declDecl = dyn_cast<DeclaratorDecl>(decl))
+      if (!declDecl->hasExternalFormalLinkage())
+        return true;
+
+    // cbuffer/tbuffer
+    if (isa<HLSLBufferDecl>(decl))
+      return true;
+
+    // Other resource types
+    if (const auto *valueDecl = dyn_cast<ValueDecl>(decl))
+      if (isResourceType(valueDecl))
+        return true;
+  }
+
+  return false;
+}
+
+void collectDeclsInNamespace(const NamespaceDecl *nsDecl,
+                             llvm::SmallVector<const Decl *, 4> *decls) {
+  for (const auto *decl : nsDecl->decls()) {
+    collectDeclsInField(decl, decls);
+  }
+}
+
+void collectDeclsInField(const Decl *field,
+                         llvm::SmallVector<const Decl *, 4> *decls) {
+
+  // Case of nested namespaces.
+  if (const auto *nsDecl = dyn_cast<NamespaceDecl>(field)) {
+    collectDeclsInNamespace(nsDecl, decls);
+  }
+
+  if (shouldSkipInStructLayout(field))
+    return;
+
+  if (!isa<DeclaratorDecl>(field)) {
+    return;
+  }
+
+  decls->push_back(field);
+}
+
+llvm::SmallVector<const Decl *, 4>
+collectDeclsInDeclContext(const DeclContext *declContext) {
+  llvm::SmallVector<const Decl *, 4> decls;
+  for (const auto *field : declContext->decls()) {
+    collectDeclsInField(field, &decls);
+  }
+  return decls;
+}
+
 /// \brief Returns true if the given decl is a boolean stage I/O variable.
 /// Returns false if the type is not boolean, or the decl is a built-in stage
 /// variable.
@@ -132,14 +219,15 @@ std::string StageVar::getSemanticStr() const {
   return ss.str();
 }
 
-uint32_t CounterIdAliasPair::get(ModuleBuilder &builder,
-                                 TypeTranslator &translator) const {
+SpirvInstruction *CounterIdAliasPair::get(SpirvBuilder &builder,
+                                          SpirvContext &spvContext) const {
   if (isAlias) {
-    const uint32_t counterVarType = builder.getPointerType(
-        translator.getACSBufferCounter(), spv::StorageClass::Uniform);
-    return builder.createLoad(counterVarType, resultId);
+    const auto *counterType = spvContext.getACSBufferCounterType();
+    const auto *counterVarType =
+        spvContext.getPointerType(counterType, spv::StorageClass::Uniform);
+    return builder.createLoad(counterVarType, counterVar);
   }
-  return resultId;
+  return counterVar;
 }
 
 const CounterIdAliasPair *
@@ -151,14 +239,14 @@ CounterVarFields::get(const llvm::SmallVectorImpl<uint32_t> &indices) const {
 }
 
 bool CounterVarFields::assign(const CounterVarFields &srcFields,
-                              ModuleBuilder &builder,
-                              TypeTranslator &translator) const {
+                              SpirvBuilder &builder,
+                              SpirvContext &context) const {
   for (const auto &field : fields) {
     const auto *srcField = srcFields.get(field.indices);
     if (!srcField)
       return false;
 
-    field.counterVar.assign(*srcField, builder, translator);
+    field.counterVar.assign(*srcField, builder, context);
   }
 
   return true;
@@ -167,10 +255,10 @@ bool CounterVarFields::assign(const CounterVarFields &srcFields,
 bool CounterVarFields::assign(const CounterVarFields &srcFields,
                               const llvm::SmallVector<uint32_t, 4> &dstPrefix,
                               const llvm::SmallVector<uint32_t, 4> &srcPrefix,
-                              ModuleBuilder &builder,
-                              TypeTranslator &translator) const {
+                              SpirvBuilder &builder,
+                              SpirvContext &context) const {
   if (dstPrefix.empty() && srcPrefix.empty())
-    return assign(srcFields, builder, translator);
+    return assign(srcFields, builder, context);
 
   llvm::SmallVector<uint32_t, 4> srcIndices = srcPrefix;
 
@@ -198,7 +286,7 @@ bool CounterVarFields::assign(const CounterVarFields &srcFields,
       if (!srcField)
         return false;
 
-      field.counterVar.assign(*srcField, builder, translator);
+      field.counterVar.assign(*srcField, builder, context);
       for (uint32_t i = srcPrefix.size(); i < srcIndices.size(); ++i)
         srcIndices.pop_back();
     }
@@ -221,7 +309,7 @@ SemanticInfo DeclResultIdMapper::getStageVarSemantic(const NamedDecl *decl) {
 }
 
 bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
-                                              uint32_t storedValue,
+                                              SpirvInstruction *storedValue,
                                               bool forPCF) {
   QualType type = getTypeOrFnRetType(decl);
 
@@ -245,7 +333,7 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
   // Write back of stage output variables in GS is manually controlled by
   // .Append() intrinsic method, implemented in writeBackOutputStream(). So
   // ignoreValue should be set to true for GS.
-  const bool noWriteBack = storedValue == 0 || shaderModel.IsGS();
+  const bool noWriteBack = storedValue == nullptr || shaderModel.IsGS();
 
   return createStageVars(sigPoint, decl, /*asInput=*/false, type,
                          /*arraySize=*/0, "out.var", llvm::None, &storedValue,
@@ -254,8 +342,8 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
 
 bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
                                               uint32_t arraySize,
-                                              uint32_t invocationId,
-                                              uint32_t storedValue) {
+                                              SpirvInstruction *invocationId,
+                                              SpirvInstruction *storedValue) {
   assert(shaderModel.IsHS());
 
   QualType type = getTypeOrFnRetType(decl);
@@ -271,7 +359,7 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
 }
 
 bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
-                                             uint32_t *loadedValue,
+                                             SpirvInstruction **loadedValue,
                                              bool forPCF) {
   uint32_t arraySize = 0;
   QualType type = paramDecl->getType();
@@ -310,25 +398,21 @@ DeclResultIdMapper::getDeclSpirvInfo(const ValueDecl *decl) const {
   return nullptr;
 }
 
-SpirvEvalInfo DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl) {
+SpirvInstruction *DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl) {
   if (const auto *info = getDeclSpirvInfo(decl)) {
     if (info->indexInCTBuffer >= 0) {
       // If this is a VarDecl inside a HLSLBufferDecl, we need to do an extra
       // OpAccessChain to get the pointer to the variable since we created
       // a single variable for the whole buffer object.
 
-      const uint32_t varType = typeTranslator.translateType(
-          // Should only have VarDecls in a HLSLBufferDecl.
-          cast<VarDecl>(decl)->getType(),
-          // We need to set decorateLayout here to avoid creating SPIR-V
-          // instructions for the current type without decorations.
-          info->info.getLayoutRule());
-
-      const uint32_t elemId = theBuilder.createAccessChain(
-          theBuilder.getPointerType(varType, info->info.getStorageClass()),
-          info->info, {theBuilder.getConstantInt32(info->indexInCTBuffer)});
+      // Should only have VarDecls in a HLSLBufferDecl.
+      QualType valueType = cast<VarDecl>(decl)->getType();
 
-      return info->info.substResultId(elemId);
+      // TODO(ehsan): Setting QualType of the value for the access chain. This
+      // used to be a pointer-to-transalted-qualtype.
+      return spvBuilder.createAccessChain(
+          valueType, info->instr,
+          {spvContext.getConstantInt32(info->indexInCTBuffer)});
     } else {
       return *info;
     }
@@ -343,17 +427,23 @@ SpirvEvalInfo DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl) {
   return 0;
 }
 
-uint32_t DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
+SpirvFunctionParameter *
+DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
+  // TODO(ehsan): Setting QualType for function parameter. In SPIR-V, this
+  // should be a pointer-to-translated-qualtype (with Function storage class).
+  const auto type = getTypeOrFnRetType(param);
+  const auto loc = param->getLocation();
+  SpirvFunctionParameter *fnParamInstr =
+      spvBuilder.addFnParam(type, loc, param->getName());
+
   bool isAlias = false;
-  auto &info = astDecls[param].info;
-  const uint32_t type =
-      getTypeAndCreateCounterForPotentialAliasVar(param, &isAlias, &info);
-  const uint32_t ptrType =
-      theBuilder.getPointerType(type, spv::StorageClass::Function);
-  const uint32_t id = theBuilder.addFnParam(ptrType, param->getName());
-  info.setResultId(id);
-
-  return id;
+  (void)getTypeAndCreateCounterForPotentialAliasVar(param, &isAlias);
+  fnParamInstr->setContainsAliasComponent(isAlias);
+
+  assert(astDecls[param].instr == nullptr);
+  astDecls[param].instr = fnParamInstr;
+
+  return fnParamInstr;
 }
 
 void DeclResultIdMapper::createCounterVarForDecl(const DeclaratorDecl *decl) {
@@ -369,32 +459,44 @@ void DeclResultIdMapper::createCounterVarForDecl(const DeclaratorDecl *decl) {
   }
 }
 
-SpirvEvalInfo DeclResultIdMapper::createFnVar(const VarDecl *var,
-                                              llvm::Optional<uint32_t> init) {
+SpirvVariable *
+DeclResultIdMapper::createFnVar(const VarDecl *var,
+                                llvm::Optional<SpirvInstruction *> init) {
+  const auto type = getTypeOrFnRetType(var);
+  const auto loc = var->getLocation();
+  const auto name = var->getName();
+  SpirvVariable *varInstr = spvBuilder.addFnVar(
+      type, loc, name, init.hasValue() ? init.getValue() : nullptr);
+
   bool isAlias = false;
-  auto &info = astDecls[var].info;
-  const uint32_t type =
-      getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias, &info);
-  const uint32_t id = theBuilder.addFnVar(type, var->getName(), init);
-  info.setResultId(id);
+  (void)getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias);
+  varInstr->setContainsAliasComponent(isAlias);
+
+  assert(astDecls[var].instr == nullptr);
+  astDecls[var].instr = varInstr;
 
-  return info;
+  return varInstr;
 }
 
-SpirvEvalInfo DeclResultIdMapper::createFileVar(const VarDecl *var,
-                                                llvm::Optional<uint32_t> init) {
+SpirvVariable *
+DeclResultIdMapper::createFileVar(const VarDecl *var,
+                                  llvm::Optional<SpirvInstruction *> init) {
+  const auto type = getTypeOrFnRetType(var);
+  const auto loc = var->getLocation();
+  SpirvVariable *varInstr = spvBuilder.addModuleVar(
+      type, spv::StorageClass::Private, var->getName(), init, loc);
+
   bool isAlias = false;
-  auto &info = astDecls[var].info;
-  const uint32_t type =
-      getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias, &info);
-  const uint32_t id = theBuilder.addModuleVar(type, spv::StorageClass::Private,
-                                              var->getName(), init);
-  info.setResultId(id).setStorageClass(spv::StorageClass::Private);
-
-  return info;
+  (void)getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias);
+  varInstr->setContainsAliasComponent(isAlias);
+
+  assert(astDecls[var].instr == nullptr);
+  astDecls[var].instr = varInstr;
+
+  return varInstr;
 }
 
-SpirvEvalInfo DeclResultIdMapper::createExternVar(const VarDecl *var) {
+SpirvVariable *DeclResultIdMapper::createExternVar(const VarDecl *var) {
   auto storageClass = spv::StorageClass::UniformConstant;
   auto rule = SpirvLayoutRule::Void;
   bool isACRWSBuffer = false; // Whether is {Append|Consume|RW}StructuredBuffer
@@ -402,7 +504,7 @@ SpirvEvalInfo DeclResultIdMapper::createExternVar(const VarDecl *var) {
   if (var->getAttr<HLSLGroupSharedAttr>()) {
     // For CS groupshared variables
     storageClass = spv::StorageClass::Workgroup;
-  } else if (TypeTranslator::isResourceType(var)) {
+  } else if (isResourceType(var)) {
     // See through the possible outer arrays
     QualType resourceType = var->getType();
     while (resourceType->isArrayType()) {
@@ -433,12 +535,13 @@ SpirvEvalInfo DeclResultIdMapper::createExternVar(const VarDecl *var) {
     if (astDecls.count(var) == 0)
       createGlobalsCBuffer(var);
 
-    return astDecls[var].info;
+    assert(isa<SpirvVariable>(astDecls[var].instr));
+    return cast<SpirvVariable>(astDecls[var].instr);
   }
 
-  uint32_t varType = typeTranslator.translateType(var->getType(), rule);
-
   // Require corresponding capability for accessing 16-bit data.
+  // TODO(ehsan): This should be removed and moved to a pass.
+  /*
   if (storageClass == spv::StorageClass::Uniform &&
       spirvOptions.enable16BitTypes &&
       typeTranslator.isOrContains16BitType(var->getType())) {
@@ -446,61 +549,41 @@ SpirvEvalInfo DeclResultIdMapper::createExternVar(const VarDecl *var) {
                             "16-bit types in resource", var->getLocation());
     theBuilder.requireCapability(spv::Capability::StorageUniformBufferBlock16);
   }
+  */
 
-  const uint32_t id = theBuilder.addModuleVar(varType, storageClass,
-                                              var->getName(), llvm::None);
-  const auto info =
-      SpirvEvalInfo(id).setStorageClass(storageClass).setLayoutRule(rule);
+  const auto type = var->getType();
+  const auto loc = var->getLocation();
+  SpirvVariable *varInstr = spvBuilder.addModuleVar(
+      type, storageClass, var->getName(), llvm::None, loc);
+  varInstr->setLayoutRule(rule);
+  DeclSpirvInfo info(varInstr);
   astDecls[var] = info;
 
   // Variables in Workgroup do not need descriptor decorations.
   if (storageClass == spv::StorageClass::Workgroup)
-    return info;
+    return varInstr;
 
   const auto *regAttr = getResourceBinding(var);
   const auto *bindingAttr = var->getAttr<VKBindingAttr>();
   const auto *counterBindingAttr = var->getAttr<VKCounterBindingAttr>();
 
-  resourceVars.emplace_back(id, var->getLocation(), regAttr, bindingAttr,
+  resourceVars.emplace_back(varInstr, loc, regAttr, bindingAttr,
                             counterBindingAttr);
 
   if (const auto *inputAttachment = var->getAttr<VKInputAttachmentIndexAttr>())
-    theBuilder.decorateInputAttachmentIndex(id, inputAttachment->getIndex());
+    spvBuilder.decorateInputAttachmentIndex(varInstr,
+                                            inputAttachment->getIndex(), loc);
 
   if (isACRWSBuffer) {
     // For {Append|Consume|RW}StructuredBuffer, we need to always create another
     // variable for its associated counter.
-    createCounterVar(var, id, /*isAlias=*/false);
+    createCounterVar(var, varInstr, /*isAlias=*/false);
   }
 
-  return info;
+  return varInstr;
 }
 
-uint32_t DeclResultIdMapper::getMatrixStructType(const VarDecl *matVar,
-                                                 spv::StorageClass sc,
-                                                 SpirvLayoutRule rule) {
-  const auto matType = matVar->getType();
-  assert(isMxNMatrix(matType));
-
-  auto &context = *theBuilder.getSPIRVContext();
-  llvm::SmallVector<const Decoration *, 4> decorations;
-  const bool isRowMajor = typeTranslator.isRowMajorMatrix(matType);
-
-  uint32_t stride;
-  (void)typeTranslator.getAlignmentAndSize(matType, rule, &stride);
-  decorations.push_back(Decoration::getOffset(context, 0, 0));
-  decorations.push_back(Decoration::getMatrixStride(context, stride, 0));
-  decorations.push_back(isRowMajor ? Decoration::getColMajor(context, 0)
-                                   : Decoration::getRowMajor(context, 0));
-  decorations.push_back(Decoration::getBlock(context));
-
-  // Get the type for the wrapping struct
-  const std::string structName = "type." + matVar->getName().str();
-  return theBuilder.getStructType({typeTranslator.translateType(matType)},
-                                  structName, {}, decorations);
-}
-
-uint32_t DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
+SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
     const DeclContext *decl, int arraySize, const ContextUsageKind usageKind,
     llvm::StringRef typeName, llvm::StringRef varName) {
   // cbuffers are translated into OpTypeStruct with Block decoration.
@@ -515,24 +598,11 @@ uint32_t DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
   const bool forGlobals = usageKind == ContextUsageKind::Globals;
   const bool forPC = usageKind == ContextUsageKind::PushConstant;
 
-  auto &context = *theBuilder.getSPIRVContext();
-  const SpirvLayoutRule layoutRule =
-      (forCBuffer || forGlobals)
-          ? spirvOptions.cBufferLayoutRule
-          : (forTBuffer ? spirvOptions.tBufferLayoutRule
-                        : spirvOptions.sBufferLayoutRule);
-  const auto *blockDec = forTBuffer ? Decoration::getBufferBlock(context)
-                                    : Decoration::getBlock(context);
-
   const llvm::SmallVector<const Decl *, 4> &declGroup =
-      typeTranslator.collectDeclsInDeclContext(decl);
-  auto decorations = typeTranslator.getLayoutDecorations(declGroup, layoutRule);
-  decorations.push_back(blockDec);
+      collectDeclsInDeclContext(decl);
 
   // Collect the type and name for each field
-  llvm::SmallVector<uint32_t, 4> fieldTypes;
-  llvm::SmallVector<llvm::StringRef, 4> fieldNames;
-  uint32_t fieldIndex = 0;
+  llvm::SmallVector<HybridStructType::FieldInfo, 4> fields;
   for (const auto *subDecl : declGroup) {
     // The field can only be FieldDecl (for normal structs) or VarDecl (for
     // HLSLBufferDecls).
@@ -543,10 +613,10 @@ uint32_t DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
     // We don't need it here.
     auto varType = declDecl->getType();
     varType.removeLocalConst();
+    HybridStructType::FieldInfo info(varType, nullptr, declDecl->getName());
 
-    fieldTypes.push_back(typeTranslator.translateType(varType, layoutRule));
-    fieldNames.push_back(declDecl->getName());
-
+    /*
+    // TODO(ehsan): This should be removed and moved to a pass.
     // Require corresponding capability for accessing 16-bit data.
     if (spirvOptions.enable16BitTypes &&
         typeTranslator.isOrContains16BitType(varType)) {
@@ -559,48 +629,51 @@ uint32_t DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
               : forPC ? spv::Capability::StoragePushConstant16
                       : spv::Capability::StorageUniformBufferBlock16);
     }
-
-    // tbuffer/TextureBuffers are non-writable SSBOs. OpMemberDecorate
-    // NonWritable must be applied to all fields.
-    if (forTBuffer) {
-      decorations.push_back(Decoration::getNonWritable(
-          *theBuilder.getSPIRVContext(), fieldIndex));
-    }
-    ++fieldIndex;
+    */
   }
 
   // Get the type for the whole struct
-  uint32_t resultType =
-      theBuilder.getStructType(fieldTypes, typeName, fieldNames, decorations);
+  // tbuffer/TextureBuffers are non-writable SSBOs.
+  const SpirvType *resultType = spvContext.getHybridStructType(
+      fields, typeName, /*isReadOnly*/ forTBuffer,
+      forTBuffer ? HybridStructType::InterfaceType::StorageBuffer
+                 : HybridStructType::InterfaceType::UniformBuffer);
 
   // Make an array if requested.
   if (arraySize > 0) {
-    resultType = theBuilder.getArrayType(
-        resultType, theBuilder.getConstantUint32(arraySize));
+    resultType = spvContext.getArrayType(resultType, arraySize);
   } else if (arraySize == -1) {
     // Runtime arrays of cbuffer/tbuffer needs additional capability.
-    theBuilder.addExtension(Extension::EXT_descriptor_indexing,
+    spvBuilder.addExtension(Extension::EXT_descriptor_indexing,
                             "runtime array of resources", {});
-    theBuilder.requireCapability(spv::Capability::RuntimeDescriptorArrayEXT);
-    resultType = theBuilder.getRuntimeArrayType(resultType);
+    spvBuilder.requireCapability(spv::Capability::RuntimeDescriptorArrayEXT);
+    resultType = spvContext.getRuntimeArrayType(resultType);
   }
 
   // Register the <type-id> for this decl
-  ctBufferPCTypeIds[decl] = resultType;
+  ctBufferPCTypes[decl] = resultType;
 
   const auto sc =
       forPC ? spv::StorageClass::PushConstant : spv::StorageClass::Uniform;
 
   // Create the variable for the whole struct / struct array.
-  return theBuilder.addModuleVar(resultType, sc, varName);
+  SpirvVariable *var = spvBuilder.addModuleVar(resultType, sc, varName);
+  const SpirvLayoutRule layoutRule =
+      (forCBuffer || forGlobals)
+          ? spirvOptions.cBufferLayoutRule
+          : (forTBuffer ? spirvOptions.tBufferLayoutRule
+                        : spirvOptions.sBufferLayoutRule);
+
+  var->setLayoutRule(layoutRule);
+  return var;
 }
 
-uint32_t DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
+SpirvVariable *DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
   const auto usageKind =
       decl->isCBuffer() ? ContextUsageKind::CBuffer : ContextUsageKind::TBuffer;
   const std::string structName = "type." + decl->getName().str();
   // The front-end does not allow arrays of cbuffer/tbuffer.
-  const uint32_t bufferVar = createStructOrStructArrayVarOfExplicitLayout(
+  SpirvVariable *bufferVar = createStructOrStructArrayVarOfExplicitLayout(
       decl, /*arraySize*/ 0, usageKind, structName, decl->getName());
 
   // We still register all VarDecls seperately here. All the VarDecls are
@@ -609,16 +682,11 @@ uint32_t DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
   // OpAccessChain.
   int index = 0;
   for (const auto *subDecl : decl->decls()) {
-    if (TypeTranslator::shouldSkipInStructLayout(subDecl))
+    if (shouldSkipInStructLayout(subDecl))
       continue;
 
     const auto *varDecl = cast<VarDecl>(subDecl);
-    astDecls[varDecl] =
-        SpirvEvalInfo(bufferVar)
-            .setStorageClass(spv::StorageClass::Uniform)
-            .setLayoutRule(decl->isCBuffer() ? spirvOptions.cBufferLayoutRule
-                                             : spirvOptions.tBufferLayoutRule);
-    astDecls[varDecl].indexInCTBuffer = index++;
+    astDecls[varDecl] = DeclSpirvInfo(bufferVar, index++);
   }
   resourceVars.emplace_back(
       bufferVar, decl->getLocation(), getResourceBinding(decl),
@@ -627,7 +695,7 @@ uint32_t DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
   return bufferVar;
 }
 
-uint32_t DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
+SpirvVariable *DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
   const RecordType *recordType = nullptr;
   int arraySize = 0;
 
@@ -659,15 +727,11 @@ uint32_t DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
   const std::string structName = "type." + std::string(ctBufferName) +
                                  recordType->getDecl()->getName().str();
 
-  const uint32_t bufferVar = createStructOrStructArrayVarOfExplicitLayout(
+  SpirvVariable *bufferVar = createStructOrStructArrayVarOfExplicitLayout(
       recordType->getDecl(), arraySize, usageKind, structName, decl->getName());
 
   // We register the VarDecl here.
-  astDecls[decl] =
-      SpirvEvalInfo(bufferVar)
-          .setStorageClass(spv::StorageClass::Uniform)
-          .setLayoutRule(context->isCBuffer() ? spirvOptions.cBufferLayoutRule
-                                              : spirvOptions.tBufferLayoutRule);
+  astDecls[decl] = DeclSpirvInfo(bufferVar);
   resourceVars.emplace_back(
       bufferVar, decl->getLocation(), getResourceBinding(context),
       decl->getAttr<VKBindingAttr>(), decl->getAttr<VKCounterBindingAttr>());
@@ -675,21 +739,20 @@ uint32_t DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
   return bufferVar;
 }
 
-uint32_t DeclResultIdMapper::createPushConstant(const VarDecl *decl) {
+SpirvVariable *DeclResultIdMapper::createPushConstant(const VarDecl *decl) {
   // The front-end errors out if non-struct type push constant is used.
   const auto *recordType = decl->getType()->getAs<RecordType>();
   assert(recordType);
 
   const std::string structName =
       "type.PushConstant." + recordType->getDecl()->getName().str();
-  const uint32_t var = createStructOrStructArrayVarOfExplicitLayout(
+  SpirvVariable *var = createStructOrStructArrayVarOfExplicitLayout(
       recordType->getDecl(), /*arraySize*/ 0, ContextUsageKind::PushConstant,
       structName, decl->getName());
 
   // Register the VarDecl
-  astDecls[decl] = SpirvEvalInfo(var)
-                       .setStorageClass(spv::StorageClass::PushConstant)
-                       .setLayoutRule(spirvOptions.sBufferLayoutRule);
+  astDecls[decl] = DeclSpirvInfo(var);
+
   // Do not push this variable into resourceVars since it does not need
   // descriptor set.
 
@@ -701,7 +764,7 @@ void DeclResultIdMapper::createGlobalsCBuffer(const VarDecl *var) {
     return;
 
   const auto *context = var->getTranslationUnitDecl();
-  const uint32_t globals = createStructOrStructArrayVarOfExplicitLayout(
+  SpirvVariable *globals = createStructOrStructArrayVarOfExplicitLayout(
       context, /*arraySize*/ 0, ContextUsageKind::Globals, "type.$Globals",
       "$Globals");
 
@@ -709,7 +772,7 @@ void DeclResultIdMapper::createGlobalsCBuffer(const VarDecl *var) {
                             nullptr);
 
   uint32_t index = 0;
-  for (const auto *decl : typeTranslator.collectDeclsInDeclContext(context))
+  for (const auto *decl : collectDeclsInDeclContext(context))
     if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
       if (!spirvOptions.noWarnIgnoredFeatures) {
         if (const auto *init = varDecl->getInit())
@@ -726,33 +789,31 @@ void DeclResultIdMapper::createGlobalsCBuffer(const VarDecl *var) {
         return;
       }
 
-      astDecls[varDecl] = SpirvEvalInfo(globals)
-                              .setStorageClass(spv::StorageClass::Uniform)
-                              .setLayoutRule(spirvOptions.cBufferLayoutRule);
-      astDecls[varDecl].indexInCTBuffer = index++;
+      astDecls[varDecl] = DeclSpirvInfo(globals, index++);
     }
 }
 
-uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
-  if (const auto *info = getDeclSpirvInfo(fn))
-    return info->info;
-
-  auto &info = astDecls[fn].info;
+SpirvFunction *DeclResultIdMapper::getOrRegisterFn(const FunctionDecl *fn) {
+  // Return it if it's already been created.
+  auto it = astFunctionDecls.find(fn);
+  if (it != astFunctionDecls.end()) {
+    return it->second;
+  }
 
   bool isAlias = false;
+  (void)getTypeAndCreateCounterForPotentialAliasVar(fn, &isAlias);
 
-  (void)getTypeAndCreateCounterForPotentialAliasVar(fn, &isAlias, &info);
+  SpirvFunction *spirvFunction = spvBuilder.createFunction(
+      fn->getReturnType(), fn->getLocation(), fn->getName(), isAlias);
 
-  const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
-  info.setResultId(id);
   // No need to dereference to get the pointer. Function returns that are
   // stand-alone aliases are already pointers to values. All other cases should
   // be normal rvalues.
-  if (!isAlias ||
-      !TypeTranslator::isAKindOfStructuredOrByteBuffer(fn->getReturnType()))
-    info.setRValue();
+  if (!isAlias || !isAKindOfStructuredOrByteBuffer(fn->getReturnType()))
+    spirvFunction->setRValue();
 
-  return id;
+  astFunctionDecls[fn] = spirvFunction;
+  return spirvFunction;
 }
 
 const CounterIdAliasPair *DeclResultIdMapper::getCounterIdAliasPair(
@@ -788,12 +849,14 @@ DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
 }
 
 void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
-                                              uint32_t specConstant) {
-  astDecls[decl].info.setResultId(specConstant).setRValue().setSpecConstant();
+                                              SpirvInstruction *specConstant) {
+  specConstant->setSpecConstant();
+  specConstant->setRValue();
+  astDecls[decl] = DeclSpirvInfo(specConstant);
 }
 
 void DeclResultIdMapper::createCounterVar(
-    const DeclaratorDecl *decl, uint32_t declId, bool isAlias,
+    const DeclaratorDecl *decl, SpirvInstruction *declInstr, bool isAlias,
     const llvm::SmallVector<uint32_t, 4> *indices) {
   std::string counterName = "counter.var." + decl->getName().str();
   if (indices) {
@@ -802,6 +865,7 @@ void DeclResultIdMapper::createCounterVar(
       counterName += "." + std::to_string(index);
   }
 
+  const SpirvType *counterType = spvContext.getACSBufferCounterType();
   uint32_t counterType = typeTranslator.getACSBufferCounter();
   // {RW|Append|Consume}StructuredBuffer are all in Uniform storage class.
   // Alias counter variables should be created into the Private storage class.
@@ -811,27 +875,27 @@ void DeclResultIdMapper::createCounterVar(
   if (isAlias) {
     // Apply an extra level of pointer for alias counter variable
     counterType =
-        theBuilder.getPointerType(counterType, spv::StorageClass::Uniform);
+        spvContext.getPointerType(counterType, spv::StorageClass::Uniform);
   }
 
-  const uint32_t counterId =
-      theBuilder.addModuleVar(counterType, sc, counterName);
+  SpirvVariable *counterInstr = spvBuilder.addModuleVar(
+      counterType, spv::StorageClass::Uniform, counterName);
 
   if (!isAlias) {
     // Non-alias counter variables should be put in to resourceVars so that
     // descriptors can be allocated for them.
-    resourceVars.emplace_back(counterId, decl->getLocation(),
+    resourceVars.emplace_back(counterInstr, decl->getLocation(),
                               getResourceBinding(decl),
                               decl->getAttr<VKBindingAttr>(),
                               decl->getAttr<VKCounterBindingAttr>(), true);
-    assert(declId);
-    theBuilder.decorateCounterBufferId(declId, counterId);
+    assert(declInstr);
+    spvBuilder.decorateCounterBuffer(declInstr, counterInstr);
   }
 
   if (indices)
-    fieldCounterVars[decl].append(*indices, counterId);
+    fieldCounterVars[decl].append(*indices, counterInstr);
   else
-    counterVars[decl] = {counterId, isAlias};
+    counterVars[decl] = {counterInstr, isAlias};
 }
 
 void DeclResultIdMapper::createFieldCounterVars(
@@ -858,29 +922,23 @@ void DeclResultIdMapper::createFieldCounterVars(
   }
 }
 
-uint32_t
-DeclResultIdMapper::getCTBufferPushConstantTypeId(const DeclContext *decl) {
-  const auto found = ctBufferPCTypeIds.find(decl);
-  assert(found != ctBufferPCTypeIds.end());
+const SpirvType *
+DeclResultIdMapper::getCTBufferPushConstantType(const DeclContext *decl) {
+  const auto found = ctBufferPCTypes.find(decl);
+  assert(found != ctBufferPCTypes.end());
   return found->second;
 }
 
-std::vector<uint32_t> DeclResultIdMapper::collectStageVars() const {
-  std::vector<uint32_t> vars;
+std::vector<SpirvVariable *> DeclResultIdMapper::collectStageVars() const {
+  std::vector<SpirvVariable *> vars;
 
   for (auto var : glPerVertex.getStageInVars())
     vars.push_back(var);
   for (auto var : glPerVertex.getStageOutVars())
     vars.push_back(var);
 
-  llvm::DenseSet<uint32_t> seenVars;
-  for (const auto &var : stageVars) {
-    const auto id = var.getSpirvId();
-    if (seenVars.count(id) == 0) {
-      vars.push_back(id);
-      seenVars.insert(id);
-    }
-  }
+  for (const auto &var : stageVars)
+    vars.push_back(var.getSpirvInstr());
 
   return vars;
 }
@@ -1049,9 +1107,9 @@ bool DeclResultIdMapper::finalizeStageIOLocations(bool forInput) {
       }
       locSet.useLoc(loc, idx);
 
-      theBuilder.decorateLocation(var.getSpirvId(), loc);
+      spvBuilder.decorateLocation(var.getSpirvInstr(), loc);
       if (var.getIndexAttr())
-        theBuilder.decorateIndex(var.getSpirvId(), idx);
+        spvBuilder.decorateIndex(var.getSpirvInstr(), idx);
     }
 
     return noError;
@@ -1079,7 +1137,7 @@ bool DeclResultIdMapper::finalizeStageIOLocations(bool forInput) {
     // We should special rules for SV_Target: the location number comes from the
     // semantic string index.
     if (semaInfo.isTarget()) {
-      theBuilder.decorateLocation(var.getSpirvId(), semaInfo.index);
+      spvBuilder.decorateLocation(var.getSpirvInstr(), semaInfo.index);
       locSet.useLoc(semaInfo.index);
     } else {
       vars.push_back(&var);
@@ -1101,7 +1159,7 @@ bool DeclResultIdMapper::finalizeStageIOLocations(bool forInput) {
   }
 
   for (const auto *var : vars)
-    theBuilder.decorateLocation(var->getSpirvId(),
+    spvBuilder.decorateLocation(var->getSpirvInstr(),
                                 locSet.useNextLocs(var->getLocationCount()));
 
   return true;
@@ -1235,7 +1293,7 @@ bool DeclResultIdMapper::decorateResourceBindings() {
                       var.getSourceLocation());
             return false;
           }
-          theBuilder.decorateDSetBinding(var.getSpirvId(), setNo, bindNo);
+          spvBuilder.decorateDSetBinding(var.getSpirvInstr(), setNo, bindNo);
         }
       } else {
         emitError(
@@ -1251,11 +1309,11 @@ bool DeclResultIdMapper::decorateResourceBindings() {
 
   // Decorates the given varId of the given category with set number
   // setNo, binding number bindingNo. Ignores overlaps.
-  const auto tryToDecorate = [this, &bindingSet](const uint32_t varId,
+  const auto tryToDecorate = [this, &bindingSet](SpirvInstruction *var,
                                                  const uint32_t setNo,
                                                  const uint32_t bindingNo) {
     bindingSet.useBinding(bindingNo, setNo);
-    theBuilder.decorateDSetBinding(varId, setNo, bindingNo);
+    spvBuilder.decorateDSetBinding(var, setNo, bindingNo);
   };
 
   for (const auto &var : resourceVars) {
@@ -1268,12 +1326,12 @@ bool DeclResultIdMapper::decorateResourceBindings() {
         else if (const auto *reg = var.getRegister())
           set = reg->RegisterSpace;
 
-        tryToDecorate(var.getSpirvId(), set, vkCBinding->getBinding());
+        tryToDecorate(var.getSpirvInstr(), set, vkCBinding->getBinding());
       }
     } else {
       if (const auto *vkBinding = var.getBinding()) {
         // Process m1
-        tryToDecorate(var.getSpirvId(), vkBinding->getSet(),
+        tryToDecorate(var.getSpirvInstr(), vkBinding->getSet(),
                       vkBinding->getBinding());
       }
     }
@@ -1314,7 +1372,7 @@ bool DeclResultIdMapper::decorateResourceBindings() {
           llvm_unreachable("unknown register type found");
         }
 
-        tryToDecorate(var.getSpirvId(), set, binding);
+        tryToDecorate(var.getSpirvInstr(), set, binding);
       }
 
   for (const auto &var : resourceVars) {
@@ -1327,18 +1385,18 @@ bool DeclResultIdMapper::decorateResourceBindings() {
         else if (const auto *reg = var.getRegister())
           set = reg->RegisterSpace;
 
-        theBuilder.decorateDSetBinding(var.getSpirvId(), set,
+        spvBuilder.decorateDSetBinding(var.getSpirvInstr(), set,
                                        bindingSet.useNextBinding(set));
       }
     } else if (!var.getBinding()) {
       const auto *reg = var.getRegister();
       if (reg && reg->isSpaceOnly()) {
         const uint32_t set = reg->RegisterSpace;
-        theBuilder.decorateDSetBinding(var.getSpirvId(), set,
+        spvBuilder.decorateDSetBinding(var.getSpirvInstr(), set,
                                        bindingSet.useNextBinding(set));
       } else if (!reg) {
         // Process m3
-        theBuilder.decorateDSetBinding(var.getSpirvId(), 0,
+        spvBuilder.decorateDSetBinding(var.getSpirvInstr(), 0,
                                        bindingSet.useNextBinding(0));
       }
     }
@@ -1347,13 +1405,11 @@ bool DeclResultIdMapper::decorateResourceBindings() {
   return true;
 }
 
-bool DeclResultIdMapper::createStageVars(const hlsl::SigPoint *sigPoint,
-                                         const NamedDecl *decl, bool asInput,
-                                         QualType type, uint32_t arraySize,
-                                         const llvm::StringRef namePrefix,
-                                         llvm::Optional<uint32_t> invocationId,
-                                         uint32_t *value, bool noWriteBack,
-                                         SemanticInfo *inheritSemantic) {
+bool DeclResultIdMapper::createStageVars(
+    const hlsl::SigPoint *sigPoint, const NamedDecl *decl, bool asInput,
+    QualType type, uint32_t arraySize, const llvm::StringRef namePrefix,
+    llvm::Optional<SpirvInstruction *> invocationId, SpirvInstruction **value,
+    bool noWriteBack, SemanticInfo *inheritSemantic) {
   // invocationId should only be used for handling HS per-vertex output.
   if (invocationId.hasValue()) {
     assert(shaderModel.IsHS() && arraySize != 0 && !asInput);
@@ -1366,8 +1422,8 @@ bool DeclResultIdMapper::createStageVars(const hlsl::SigPoint *sigPoint,
     return true;
   }
 
-  // The type the variable is evaluated as for SPIR-V.
-  QualType evalType = type;
+  // uint32_t typeId = typeTranslator.translateType(type);
+  const SpirvType *spvType = nullptr;
 
   // We have several cases regarding HLSL semantics to handle here:
   // * If the currrent decl inherits a semantic from some enclosing entity,
@@ -2645,8 +2701,8 @@ DeclResultIdMapper::getStorageClassForSigPoint(const hlsl::SigPoint *sigPoint) {
   return sc;
 }
 
-uint32_t DeclResultIdMapper::getTypeAndCreateCounterForPotentialAliasVar(
-    const DeclaratorDecl *decl, bool *shouldBeAlias, SpirvEvalInfo *info) {
+QualType DeclResultIdMapper::getTypeAndCreateCounterForPotentialAliasVar(
+    const DeclaratorDecl *decl, bool *shouldBeAlias) {
   if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
     // This method is only intended to be used to create SPIR-V variables in the
     // Function or Private storage class.
@@ -2661,23 +2717,20 @@ uint32_t DeclResultIdMapper::getTypeAndCreateCounterForPotentialAliasVar(
     // For ConstantBuffer and TextureBuffer
     if (buffer->isConstantBufferView())
       genAlias = true;
-  } else if (TypeTranslator::isOrContainsAKindOfStructuredOrByteBuffer(type)) {
+  } else if (isOrContainsAKindOfStructuredOrByteBuffer(type)) {
     genAlias = true;
   }
 
+  // Return via parameter whether alias was generated.
   if (shouldBeAlias)
     *shouldBeAlias = genAlias;
 
   if (genAlias) {
     needsLegalization = true;
-
     createCounterVarForDecl(decl);
-
-    if (info)
-      info->setContainsAliasComponent(true);
   }
 
-  return typeTranslator.translateType(type);
+  return type;
 }
 
 } // end namespace spirv

+ 119 - 110
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -21,6 +21,7 @@
 #include "clang/AST/Attr.h"
 #include "clang/SPIRV/FeatureManager.h"
 #include "clang/SPIRV/ModuleBuilder.h"
+#include "clang/SPIRV/SpirvBuilder.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/SmallVector.h"
@@ -54,10 +55,10 @@ struct SemanticInfo {
 class StageVar {
 public:
   inline StageVar(const hlsl::SigPoint *sig, SemanticInfo semaInfo,
-                  const VKBuiltInAttr *builtin, uint32_t type,
+                  const VKBuiltInAttr *builtin, const SpirvType *spvType,
                   uint32_t locCount)
       : sigPoint(sig), semanticInfo(std::move(semaInfo)), builtinAttr(builtin),
-        typeId(type), valueId(0), isBuiltin(false),
+        type(spvType), value(nullptr), isBuiltin(false),
         storageClass(spv::StorageClass::Max), location(nullptr),
         locationCount(locCount) {
     isBuiltin = builtinAttr != nullptr;
@@ -67,10 +68,10 @@ public:
   const SemanticInfo &getSemanticInfo() const { return semanticInfo; }
   std::string getSemanticStr() const;
 
-  uint32_t getSpirvTypeId() const { return typeId; }
+  const SpirvType *getSpirvType() const { return type; }
 
-  uint32_t getSpirvId() const { return valueId; }
-  void setSpirvId(uint32_t id) { valueId = id; }
+  SpirvVariable *getSpirvInstr() const { return value; }
+  void setSpirvInstr(SpirvVariable *spvInstr) { value = spvInstr; }
 
   const VKBuiltInAttr *getBuiltInAttr() const { return builtinAttr; }
 
@@ -96,10 +97,10 @@ private:
   SemanticInfo semanticInfo;
   /// SPIR-V BuiltIn attribute.
   const VKBuiltInAttr *builtinAttr;
-  /// SPIR-V <type-id>.
-  uint32_t typeId;
-  /// SPIR-V <result-id>.
-  uint32_t valueId;
+  /// SPIR-V type.
+  const SpirvType *type;
+  /// SPIR-V instruction.
+  SpirvVariable *value;
   /// Indicates whether this stage variable should be a SPIR-V builtin.
   bool isBuiltin;
   /// SPIR-V storage class this stage variable belongs to.
@@ -114,13 +115,13 @@ private:
 
 class ResourceVar {
 public:
-  ResourceVar(uint32_t id, SourceLocation loc,
+  ResourceVar(SpirvVariable *var, SourceLocation loc,
               const hlsl::RegisterAssignment *r, const VKBindingAttr *b,
               const VKCounterBindingAttr *cb, bool counter = false)
-      : varId(id), srcLoc(loc), reg(r), binding(b), counterBinding(cb),
+      : variable(var), srcLoc(loc), reg(r), binding(b), counterBinding(cb),
         isCounterVar(counter) {}
 
-  uint32_t getSpirvId() const { return varId; }
+  SpirvVariable *getSpirvInstr() const { return variable; }
   SourceLocation getSourceLocation() const { return srcLoc; }
   const hlsl::RegisterAssignment *getRegister() const { return reg; }
   const VKBindingAttr *getBinding() const { return binding; }
@@ -130,7 +131,7 @@ public:
   }
 
 private:
-  uint32_t varId;                             ///< <result-id>
+  SpirvVariable *variable;                    ///< The variable
   SourceLocation srcLoc;                      ///< Source location
   const hlsl::RegisterAssignment *reg;        ///< HLSL register assignment
   const VKBindingAttr *binding;               ///< Vulkan binding assignment
@@ -138,24 +139,25 @@ private:
   bool isCounterVar;                          ///< Couter variable or not
 };
 
-/// A (<result-id>, is-alias-or-not) pair for counter variables
+/// A (instruction-pointer, is-alias-or-not) pair for counter variables
 class CounterIdAliasPair {
 public:
   /// Default constructor to satisfy llvm::DenseMap
-  CounterIdAliasPair() : resultId(0), isAlias(false) {}
-  CounterIdAliasPair(uint32_t id, bool alias) : resultId(id), isAlias(alias) {}
+  CounterIdAliasPair() : counterVar(nullptr), isAlias(false) {}
+  CounterIdAliasPair(SpirvVariable *var, bool alias)
+      : counterVar(var), isAlias(alias) {}
 
   /// Returns the pointer to the counter variable. Dereferences first if this is
   /// an alias to a counter variable.
-  uint32_t get(ModuleBuilder &builder, TypeTranslator &translator) const;
+  SpirvInstruction *get(SpirvBuilder &builder, SpirvContext &spvContext) const;
 
   /// Stores the counter variable's pointer in srcPair to the curent counter
   /// variable. The current counter variable must be an alias.
-  inline void assign(const CounterIdAliasPair &srcPair, ModuleBuilder &builder,
-                     TypeTranslator &translator) const;
+  inline void assign(const CounterIdAliasPair &srcPair, SpirvBuilder &,
+                     SpirvContext &) const;
 
 private:
-  uint32_t resultId;
+  SpirvVariable *counterVar;
   /// Note: legalization specific code
   bool isAlias;
 };
@@ -205,7 +207,8 @@ public:
   CounterVarFields() = default;
 
   /// Registers a field's associated counter.
-  void append(const llvm::SmallVector<uint32_t, 4> &indices, uint32_t counter) {
+  void append(const llvm::SmallVector<uint32_t, 4> &indices,
+              SpirvVariable *counter) {
     fields.emplace_back(indices, counter);
   }
 
@@ -220,17 +223,17 @@ public:
   /// This first overload is for assigning a struct as whole: we need to update
   /// all the associated counters in the target struct. This second overload is
   /// for assigning a potentially nested struct.
-  bool assign(const CounterVarFields &srcFields, ModuleBuilder &builder,
-              TypeTranslator &translator) const;
+  bool assign(const CounterVarFields &srcFields, SpirvBuilder &,
+              SpirvContext &) const;
   bool assign(const CounterVarFields &srcFields,
               const llvm::SmallVector<uint32_t, 4> &dstPrefix,
-              const llvm::SmallVector<uint32_t, 4> &srcPrefix,
-              ModuleBuilder &builder, TypeTranslator &translator) const;
+              const llvm::SmallVector<uint32_t, 4> &srcPrefix, SpirvBuilder &,
+              SpirvContext &) const;
 
 private:
   struct IndexCounterPair {
     IndexCounterPair(const llvm::SmallVector<uint32_t, 4> &idx,
-                     uint32_t counter)
+                     SpirvVariable *counter)
         : indices(idx), counterVar(counter, true) {}
 
     llvm::SmallVector<uint32_t, 4> indices; ///< Index vector
@@ -259,7 +262,8 @@ private:
 class DeclResultIdMapper {
 public:
   inline DeclResultIdMapper(const hlsl::ShaderModel &stage, ASTContext &context,
-                            ModuleBuilder &builder, SPIRVEmitter &emitter,
+                            SpirvContext &spirvContext, ModuleBuilder &builder,
+                            SpirvBuilder &spirvBuilder, SPIRVEmitter &emitter,
                             TypeTranslator &translator,
                             FeatureManager &features,
                             const SpirvCodeGenOptions &spirvOptions);
@@ -276,23 +280,24 @@ public:
   ///
   /// Note that the control point stage output variable of HS should be created
   /// by the other overload.
-  bool createStageOutputVar(const DeclaratorDecl *decl, uint32_t storedValue,
-                            bool forPCF);
+  bool createStageOutputVar(const DeclaratorDecl *decl,
+                            SpirvInstruction *storedValue, bool forPCF);
   /// \brief Overload for handling HS control point stage ouput variable.
   bool createStageOutputVar(const DeclaratorDecl *decl, uint32_t arraySize,
-                            uint32_t invocationId, uint32_t storedValue);
+                            SpirvInstruction *invocationId,
+                            SpirvInstruction *storedValue);
 
   /// \brief Creates the stage input variables by parsing the semantics attached
   /// to the given function's parameter and returns true on success. SPIR-V
   /// instructions will also be generated to load the contents from the input
   /// variables and composite them into one and write to *loadedValue. forPCF
   /// should be set to true for handling decls in patch constant function.
-  bool createStageInputVar(const ParmVarDecl *paramDecl, uint32_t *loadedValue,
-                           bool forPCF);
+  bool createStageInputVar(const ParmVarDecl *paramDecl,
+                           SpirvInstruction **loadedValue, bool forPCF);
 
   /// \brief Creates a function-scope paramter in the current function and
-  /// returns its <result-id>.
-  uint32_t createFnParam(const ParmVarDecl *param);
+  /// returns its instruction.
+  SpirvFunctionParameter *createFnParam(const ParmVarDecl *param);
 
   /// \brief Creates the counter variable associated with the given param.
   /// This is meant to be used for forward-declared functions and this objects
@@ -302,15 +307,16 @@ public:
   inline void createFnParamCounterVar(const VarDecl *param);
 
   /// \brief Creates a function-scope variable in the current function and
-  /// returns its <result-id>.
-  SpirvEvalInfo createFnVar(const VarDecl *var, llvm::Optional<uint32_t> init);
+  /// returns its instruction.
+  SpirvVariable *createFnVar(const VarDecl *var,
+                             llvm::Optional<SpirvInstruction *> init);
 
-  /// \brief Creates a file-scope variable and returns its <result-id>.
-  SpirvEvalInfo createFileVar(const VarDecl *var,
-                              llvm::Optional<uint32_t> init);
+  /// \brief Creates a file-scope variable and returns its instruction.
+  SpirvVariable *createFileVar(const VarDecl *var,
+                               llvm::Optional<SpirvInstruction *> init);
 
-  /// \brief Creates an external-visible variable and returns its <result-id>.
-  SpirvEvalInfo createExternVar(const VarDecl *var);
+  /// \brief Creates an external-visible variable and returns its instruction.
+  SpirvVariable *createExternVar(const VarDecl *var);
 
   /// \brief Creates a cbuffer/tbuffer from the given decl.
   ///
@@ -321,7 +327,7 @@ public:
   /// for the whole buffer. When we refer to the field VarDecl later, we need
   /// to do an extra OpAccessChain to get its pointer from the SPIR-V variable
   /// standing for the whole buffer.
-  uint32_t createCTBuffer(const HLSLBufferDecl *decl);
+  SpirvVariable *createCTBuffer(const HLSLBufferDecl *decl);
 
   /// \brief Creates a cbuffer/tbuffer from the given decl.
   ///
@@ -331,10 +337,10 @@ public:
   /// TextureBuffer is parameterized. For a such VarDecl, we need to create
   /// a corresponding SPIR-V variable for it. Later referencing of such a
   /// VarDecl does not need an extra OpAccessChain.
-  uint32_t createCTBuffer(const VarDecl *decl);
+  SpirvVariable *createCTBuffer(const VarDecl *decl);
 
   /// \brief Creates a PushConstant block from the given decl.
-  uint32_t createPushConstant(const VarDecl *decl);
+  SpirvVariable *createPushConstant(const VarDecl *decl);
 
   /// \brief Creates the $Globals cbuffer.
   void createGlobalsCBuffer(const VarDecl *var);
@@ -349,27 +355,26 @@ public:
   /// writes storage class, layout rule, and valTypeId to *info.
   ///
   /// Note: legalization specific code
-  uint32_t
+  QualType
   getTypeAndCreateCounterForPotentialAliasVar(const DeclaratorDecl *var,
-                                              bool *shouldBeAlias = nullptr,
-                                              SpirvEvalInfo *info = nullptr);
+                                              bool *shouldBeAlias = nullptr);
 
-  /// \brief Sets the <result-id> of the entry function.
-  void setEntryFunctionId(uint32_t id) { entryFunctionId = id; }
+  /// \brief Sets the entry function.
+  void setEntryFunction(SpirvFunction *fn) { entryFunction = fn; }
 
 private:
   /// The struct containing SPIR-V information of a AST Decl.
   struct DeclSpirvInfo {
     /// Default constructor to satisfy DenseMap
-    DeclSpirvInfo() : info(0), indexInCTBuffer(-1) {}
+    DeclSpirvInfo() : instr(nullptr), indexInCTBuffer(-1) {}
 
-    DeclSpirvInfo(const SpirvEvalInfo &info_, int index = -1)
-        : info(info_), indexInCTBuffer(index) {}
+    DeclSpirvInfo(SpirvInstruction *instr_, int index = -1)
+        : instr(instr_), indexInCTBuffer(index) {}
 
-    /// Implicit conversion to SpirvEvalInfo.
-    operator SpirvEvalInfo() const { return info; }
+    /// Implicit conversion to SpirvInstruction*.
+    operator SpirvInstruction *() const { return instr; }
 
-    SpirvEvalInfo info;
+    SpirvInstruction *instr;
     /// Value >= 0 means that this decl is a VarDecl inside a cbuffer/tbuffer
     /// and this is the index; value < 0 means this is just a standalone decl.
     int indexInCTBuffer;
@@ -383,18 +388,19 @@ public:
   /// \brief Returns the information for the given decl.
   ///
   /// This method will panic if the given decl is not registered.
-  SpirvEvalInfo getDeclEvalInfo(const ValueDecl *decl);
+  SpirvInstruction *getDeclEvalInfo(const ValueDecl *decl);
 
-  /// \brief Returns the <result-id> for the given function if already
+  /// \brief Returns the instruction pointer for the given function if already
   /// registered; otherwise, treats the given function as a normal decl and
-  /// returns a newly assigned <result-id> for it.
-  uint32_t getOrRegisterFnResultId(const FunctionDecl *fn);
+  /// returns a newly created instruction for it.
+  SpirvFunction *getOrRegisterFn(const FunctionDecl *fn);
 
   /// Registers that the given decl should be translated into the given spec
   /// constant.
-  void registerSpecConstant(const VarDecl *decl, uint32_t specConstant);
+  void registerSpecConstant(const VarDecl *decl,
+                            SpirvInstruction *specConstant);
 
-  /// \brief Returns the associated counter's (<result-id>, is-alias-or-not)
+  /// \brief Returns the associated counter's (instr-ptr, is-alias-or-not)
   /// pair for the given {RW|Append|Consume}StructuredBuffer variable.
   /// If indices is not nullptr, walks trhough the fields of the decl, expected
   /// to be of struct type, using the indices to find the field. Returns nullptr
@@ -419,31 +425,31 @@ public:
   /// we need to have the additional Block/BufferBlock decoration to keep
   /// type consistent. Normal translation path for structs via TypeTranslator
   /// won't attach Block/BufferBlock decoration.
-  uint32_t getCTBufferPushConstantTypeId(const DeclContext *decl);
+  const SpirvType *getCTBufferPushConstantType(const DeclContext *decl);
 
   /// \brief Returns all defined stage (builtin/input/ouput) variables in this
   /// mapper.
-  std::vector<uint32_t> collectStageVars() const;
+  std::vector<SpirvVariable *> collectStageVars() const;
 
   /// \brief Writes out the contents in the function parameter for the GS
   /// stream output to the corresponding stage output variables in a recursive
   /// manner. Returns true on success, false if errors occur.
   ///
   /// decl is the Decl with semantic string attached and will be used to find
-  /// the stage output variable to write to, value is the <result-id> for the
-  /// SPIR-V variable to read data from.
+  /// the stage output variable to write to, value is the  SPIR-V variable to
+  /// read data from.
   ///
   /// This method is specially for writing back per-vertex data at the time of
   /// OpEmitVertex in GS.
   bool writeBackOutputStream(const NamedDecl *decl, QualType type,
-                             uint32_t value);
+                             SpirvVariable *value);
 
   /// \brief Negates to get the additive inverse of SV_Position.y if requested.
-  uint32_t invertYIfRequested(uint32_t position);
+  SpirvInstruction *invertYIfRequested(SpirvInstruction *position);
 
   /// \brief Reciprocates to get the multiplicative inverse of SV_Position.w
   /// if requested.
-  uint32_t invertWIfRequested(uint32_t position);
+  SpirvInstruction *invertWIfRequested(SpirvInstruction *position);
 
   /// \brief Decorates all stage input and output variables with proper
   /// location and returns true on success.
@@ -511,11 +517,6 @@ private:
   /// construction.
   bool finalizeStageIOLocations(bool forInput);
 
-  /// \brief Wraps the given matrix type with a struct and returns the struct
-  /// type's <result-id>.
-  uint32_t getMatrixStructType(const VarDecl *matVar, spv::StorageClass,
-                               SpirvLayoutRule);
-
   /// \brief An enum class for representing what the DeclContext is used for
   enum class ContextUsageKind {
     CBuffer,
@@ -538,7 +539,7 @@ private:
   /// variable will be created as a runtime array.
   ///
   /// Panics if the DeclContext is neither HLSLBufferDecl or RecordDecl.
-  uint32_t createStructOrStructArrayVarOfExplicitLayout(
+  SpirvVariable *createStructOrStructArrayVarOfExplicitLayout(
       const DeclContext *decl, int arraySize, ContextUsageKind usageKind,
       llvm::StringRef typeName, llvm::StringRef varName);
 
@@ -572,15 +573,17 @@ private:
   bool createStageVars(const hlsl::SigPoint *sigPoint, const NamedDecl *decl,
                        bool asInput, QualType asType, uint32_t arraySize,
                        const llvm::StringRef namePrefix,
-                       llvm::Optional<uint32_t> invocationId, uint32_t *value,
-                       bool noWriteBack, SemanticInfo *inheritSemantic);
+                       llvm::Optional<SpirvInstruction *> invocationId,
+                       SpirvInstruction **value, bool noWriteBack,
+                       SemanticInfo *inheritSemantic);
 
   /// Creates the SPIR-V variable instruction for the given StageVar and returns
-  /// the <result-id>. Also sets whether the StageVar is a SPIR-V builtin and
+  /// the instruction. Also sets whether the StageVar is a SPIR-V builtin and
   /// its storage class accordingly. name will be used as the debug name when
   /// creating a stage input/output variable.
-  uint32_t createSpirvStageVar(StageVar *, const NamedDecl *decl,
-                               const llvm::StringRef name, SourceLocation);
+  SpirvVariable *createSpirvStageVar(StageVar *, const NamedDecl *decl,
+                                     const llvm::StringRef name,
+                                     SourceLocation);
 
   /// Returns true if all vk:: attributes usages are valid.
   bool validateVKAttributes(const NamedDecl *decl);
@@ -592,13 +595,13 @@ private:
   /// Methods for creating counter variables associated with the given decl.
 
   /// Creates assoicated counter variables for all AssocCounter cases (see the
-  /// comment of CounterVarFields). fields.
+  /// comment of CounterVarFields).
   void createCounterVarForDecl(const DeclaratorDecl *decl);
   /// Creates the associated counter variable for final RW/Append/Consume
   /// structured buffer. Handles AssocCounter#1 and AssocCounter#2 (see the
   /// comment of CounterVarFields).
   ///
-  /// declId is the SPIR-V <result-id> for the given decl. It should be non-zero
+  /// declId is the SPIR-V instruction for the given decl. It should be non-zero
   /// for non-alias buffers.
   ///
   /// The counter variable will be created as an alias variable (of
@@ -606,7 +609,8 @@ private:
   ///
   /// Note: isAlias - legalization specific code
   void
-  createCounterVar(const DeclaratorDecl *decl, uint32_t declId, bool isAlias,
+  createCounterVar(const DeclaratorDecl *decl, SpirvInstruction *declInstr,
+                   bool isAlias,
                    const llvm::SmallVector<uint32_t, 4> *indices = nullptr);
   /// Creates all assoicated counter variables by recursively visiting decl's
   /// fields. Handles AssocCounter#3 and AssocCounter#4 (see the comment of
@@ -616,10 +620,10 @@ private:
                               const DeclaratorDecl *decl,
                               llvm::SmallVector<uint32_t, 4> *indices);
 
-  /// Decorates varId of the given asType with proper interpolation modes
+  /// Decorates varInstr of the given asType with proper interpolation modes
   /// considering the attributes on the given decl.
   void decoratePSInterpolationMode(const NamedDecl *decl, QualType asType,
-                                   uint32_t varId);
+                                   SpirvVariable *varInstr);
 
   /// Returns the proper SPIR-V storage class (Input or Output) for the given
   /// SigPoint.
@@ -631,31 +635,33 @@ private:
 private:
   const hlsl::ShaderModel &shaderModel;
   ModuleBuilder &theBuilder;
+  SpirvBuilder &spvBuilder;
   SPIRVEmitter &theEmitter;
   const SpirvCodeGenOptions &spirvOptions;
   ASTContext &astContext;
+  SpirvContext &spvContext;
   DiagnosticsEngine &diags;
 
   TypeTranslator &typeTranslator;
 
-  uint32_t entryFunctionId;
+  SpirvFunction *entryFunction;
 
-  /// Mapping of all Clang AST decls to their <result-id>s.
+  /// Mapping of all Clang AST decls to their instruction pointers.
   llvm::DenseMap<const ValueDecl *, DeclSpirvInfo> astDecls;
+  llvm::DenseMap<const ValueDecl *, SpirvFunction *> astFunctionDecls;
   /// Vector of all defined stage variables.
   llvm::SmallVector<StageVar, 8> stageVars;
-  /// Mapping from Clang AST decls to the corresponding stage variables'
-  /// <result-id>s.
+  /// Mapping from Clang AST decls to the corresponding stage variables.
   /// This field is only used by GS for manually emitting vertices, when
-  /// we need to query the <result-id> of the output stage variables
-  /// involved in writing back. For other cases, stage variable reading
-  /// and writing is done at the time of creating that stage variable,
-  /// so that we don't need to query them again for reading and writing.
-  llvm::DenseMap<const ValueDecl *, uint32_t> stageVarIds;
+  /// we need to query the output stage variables involved in writing back. For
+  /// other cases, stage variable reading and writing is done at the time of
+  /// creating that stage variable, so that we don't need to query them again
+  /// for reading and writing.
+  llvm::DenseMap<const ValueDecl *, SpirvVariable *> stageVarIds;
   /// Vector of all defined resource variables.
   llvm::SmallVector<ResourceVar, 8> resourceVars;
   /// Mapping from {RW|Append|Consume}StructuredBuffers to their
-  /// counter variables' (<result-id>, is-alias-or-not) pairs
+  /// counter variables' (instr-ptr, is-alias-or-not) pairs
   ///
   /// conterVars holds entities of AssocCounter#1, fieldCounterVars holds
   /// entities of the rest.
@@ -663,17 +669,17 @@ private:
   llvm::DenseMap<const DeclaratorDecl *, CounterVarFields> fieldCounterVars;
 
   /// Mapping from cbuffer/tbuffer/ConstantBuffer/TextureBufer/push-constant
-  /// to the <type-id>
-  llvm::DenseMap<const DeclContext *, uint32_t> ctBufferPCTypeIds;
+  /// to the SPIR-V type.
+  llvm::DenseMap<const DeclContext *, const SpirvType *> ctBufferPCTypes;
 
-  /// <result-id> for the SPIR-V builtin variables accessed by
-  /// WaveGetLaneCount() and WaveGetLaneIndex().
+  /// The SPIR-V builtin variables accessed by WaveGetLaneCount() and
+  /// WaveGetLaneIndex().
   ///
   /// These are the only two cases that SPIR-V builtin variables are accessed
   /// using HLSL intrinsic function calls. All other builtin variables are
   /// accessed using stage IO variables.
-  uint32_t laneCountBuiltinId;
-  uint32_t laneIndexBuiltinId;
+  SpirvVariable *laneCountBuiltinVar;
+  SpirvVariable *laneIndexBuiltinVar;
 
   /// Whether the translated SPIR-V binary needs legalization.
   ///
@@ -740,22 +746,25 @@ bool SemanticInfo::isTarget() const {
 }
 
 void CounterIdAliasPair::assign(const CounterIdAliasPair &srcPair,
-                                ModuleBuilder &builder,
-                                TypeTranslator &translator) const {
+                                SpirvBuilder &builder,
+                                SpirvContext &context) const {
   assert(isAlias);
-  builder.createStore(resultId, srcPair.get(builder, translator));
+  builder.createStore(counterVar, srcPair.get(builder, context));
 }
 
 DeclResultIdMapper::DeclResultIdMapper(
-    const hlsl::ShaderModel &model, ASTContext &context, ModuleBuilder &builder,
-    SPIRVEmitter &emitter, TypeTranslator &translator, FeatureManager &features,
+    const hlsl::ShaderModel &model, ASTContext &context,
+    SpirvContext &spirvContext, ModuleBuilder &builder,
+    SpirvBuilder &spirvBuilder, SPIRVEmitter &emitter,
+    TypeTranslator &translator, FeatureManager &features,
     const SpirvCodeGenOptions &options)
-    : shaderModel(model), theBuilder(builder), theEmitter(emitter),
-      spirvOptions(options), astContext(context),
-      diags(context.getDiagnostics()), typeTranslator(translator),
-      entryFunctionId(0), laneCountBuiltinId(0), laneIndexBuiltinId(0),
+    : shaderModel(model), theBuilder(builder), spvBuilder(spirvBuilder),
+      theEmitter(emitter), spirvOptions(options), astContext(context),
+      spvContext(spirvContext), diags(context.getDiagnostics()),
+      typeTranslator(translator), entryFunction(nullptr),
+      laneCountBuiltinVar(nullptr), laneIndexBuiltinVar(nullptr),
       needsLegalization(false),
-      glPerVertex(model, context, builder, typeTranslator) {}
+      glPerVertex(model, context, spirvContext, spirvBuilder) {}
 
 bool DeclResultIdMapper::decorateStageIOLocations() {
   // Try both input and output even if input location assignment failed

+ 8 - 2
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -305,8 +305,14 @@ bool EmitVisitor::visit(SpirvDecoration *inst) {
   if (inst->isMemberDecoration())
     curInst.push_back(inst->getMemberIndex());
   curInst.push_back(static_cast<uint32_t>(inst->getDecoration()));
-  curInst.insert(curInst.end(), inst->getParams().begin(),
-                 inst->getParams().end());
+  if (!inst->getParams().empty()) {
+    curInst.insert(curInst.end(), inst->getParams().begin(),
+                   inst->getParams().end());
+  }
+  if (!inst->getIdParams().empty()) {
+    curInst.insert(curInst.end(), inst->getIdParams().begin(),
+                   inst->getIdParams().end());
+  }
   finalizeInstruction();
   return true;
 }

+ 96 - 93
tools/clang/lib/SPIRV/GlPerVertex.cpp

@@ -61,12 +61,12 @@ inline bool hasGSPrimitiveTypeQualifier(const DeclaratorDecl *decl) {
 } // anonymous namespace
 
 GlPerVertex::GlPerVertex(const hlsl::ShaderModel &sm, ASTContext &context,
-                         ModuleBuilder &builder, TypeTranslator &translator)
-    : shaderModel(sm), astContext(context), theBuilder(builder), inClipVar(0),
-      inCullVar(0), outClipVar(0), outCullVar(0), inArraySize(0),
-      outArraySize(0), inClipArraySize(1), outClipArraySize(1),
-      inCullArraySize(1), outCullArraySize(1), inSemanticStrs(2, ""),
-      outSemanticStrs(2, "") {}
+                         SpirvContext &spirvContext, SpirvBuilder &spirvBuilder)
+    : shaderModel(sm), astContext(context), spvContext(spirvContext),
+      spvBuilder(spirvBuilder), inClipVar(nullptr), inCullVar(nullptr),
+      outClipVar(nullptr), outCullVar(nullptr), inArraySize(0), outArraySize(0),
+      inClipArraySize(1), outClipArraySize(1), inCullArraySize(1),
+      outCullArraySize(1), inSemanticStrs(2, ""), outSemanticStrs(2, "") {}
 
 void GlPerVertex::generateVars(uint32_t inArrayLen, uint32_t outArrayLen) {
   inArraySize = inArrayLen;
@@ -86,8 +86,8 @@ void GlPerVertex::generateVars(uint32_t inArrayLen, uint32_t outArrayLen) {
                                            outCullArraySize);
 }
 
-llvm::SmallVector<uint32_t, 2> GlPerVertex::getStageInVars() const {
-  llvm::SmallVector<uint32_t, 2> vars;
+llvm::SmallVector<SpirvVariable *, 2> GlPerVertex::getStageInVars() const {
+  llvm::SmallVector<SpirvVariable *, 2> vars;
 
   if (inClipVar)
     vars.push_back(inClipVar);
@@ -97,8 +97,8 @@ llvm::SmallVector<uint32_t, 2> GlPerVertex::getStageInVars() const {
   return vars;
 }
 
-llvm::SmallVector<uint32_t, 2> GlPerVertex::getStageOutVars() const {
-  llvm::SmallVector<uint32_t, 2> vars;
+llvm::SmallVector<SpirvVariable *, 2> GlPerVertex::getStageOutVars() const {
+  llvm::SmallVector<SpirvVariable *, 2> vars;
 
   if (outClipVar)
     vars.push_back(outClipVar);
@@ -110,10 +110,10 @@ llvm::SmallVector<uint32_t, 2> GlPerVertex::getStageOutVars() const {
 
 void GlPerVertex::requireCapabilityIfNecessary() {
   if (!inClipType.empty() || !outClipType.empty())
-    theBuilder.requireCapability(spv::Capability::ClipDistance);
+    spvBuilder.requireCapability(spv::Capability::ClipDistance);
 
   if (!inCullType.empty() || !outCullType.empty())
-    theBuilder.requireCapability(spv::Capability::CullDistance);
+    spvBuilder.requireCapability(spv::Capability::CullDistance);
 }
 
 bool GlPerVertex::recordGlPerVertexDeclFacts(const DeclaratorDecl *decl,
@@ -314,35 +314,36 @@ void GlPerVertex::calculateClipCullDistanceArraySize() {
   updateSizeAndOffset(outCullType, &outCullOffset, &outCullArraySize);
 }
 
-uint32_t GlPerVertex::createClipCullDistanceVar(bool asInput, bool isClip,
-                                                uint32_t arraySize) {
-  uint32_t type = theBuilder.getArrayType(
-      theBuilder.getFloat32Type(), theBuilder.getConstantUint32(arraySize));
+SpirvVariable *GlPerVertex::createClipCullDistanceVar(bool asInput, bool isClip,
+                                                      uint32_t arraySize) {
+  const ArrayType *type =
+      spvContext.getArrayType(spvContext.getFloatType(32), arraySize);
+
   if (asInput && inArraySize != 0) {
-    type = theBuilder.getArrayType(type,
-                                   theBuilder.getConstantUint32(inArraySize));
+    type = spvContext.getArrayType(type, inArraySize);
   } else if (!asInput && outArraySize != 0) {
-    type = theBuilder.getArrayType(type,
-                                   theBuilder.getConstantUint32(outArraySize));
+    type = spvContext.getArrayType(type, outArraySize);
   }
 
   spv::StorageClass sc =
       asInput ? spv::StorageClass::Input : spv::StorageClass::Output;
 
-  auto id = theBuilder.addStageBuiltinVar(type, sc,
-                                          isClip ? spv::BuiltIn::ClipDistance
-                                                 : spv::BuiltIn::CullDistance);
+  SpirvVariable *var = spvBuilder.addStageBuiltinVar(
+      type, sc,
+      isClip ? spv::BuiltIn::ClipDistance : spv::BuiltIn::CullDistance);
+
   const auto index = isClip ? gClipDistanceIndex : gCullDistanceIndex;
-  theBuilder.decorateHlslSemantic(id, asInput ? inSemanticStrs[index]
-                                              : outSemanticStrs[index]);
-  return id;
+  spvBuilder.decorateHlslSemantic(var, asInput ? inSemanticStrs[index]
+                                               : outSemanticStrs[index]);
+  return var;
 }
 
 bool GlPerVertex::tryToAccess(hlsl::SigPoint::Kind sigPointKind,
                               hlsl::Semantic::Kind semanticKind,
                               uint32_t semanticIndex,
-                              llvm::Optional<uint32_t> invocationId,
-                              uint32_t *value, bool noWriteBack) {
+                              llvm::Optional<SpirvInstruction *> invocationId,
+                              SpirvInstruction **value, bool noWriteBack) {
+  assert(value);
   // invocationId should only be used for HSPCOut.
   assert(invocationId.hasValue() ? sigPointKind == hlsl::SigPoint::Kind::HSCPOut
                                  : true);
@@ -379,15 +380,16 @@ bool GlPerVertex::tryToAccess(hlsl::SigPoint::Kind sigPointKind,
   return false;
 }
 
-uint32_t GlPerVertex::readClipCullArrayAsType(bool isClip, uint32_t offset,
-                                              QualType asType) const {
-  const uint32_t clipCullVar = isClip ? inClipVar : inCullVar;
+SpirvInstruction *GlPerVertex::readClipCullArrayAsType(bool isClip,
+                                                       uint32_t offset,
+                                                       QualType asType) const {
+  SpirvVariable *clipCullVar = isClip ? inClipVar : inCullVar;
 
   // The ClipDistance/CullDistance is always an float array. We are accessing
   // it using pointers, which should be of pointer to float type.
-  const uint32_t f32Type = theBuilder.getFloat32Type();
-  const uint32_t ptrType =
-      theBuilder.getPointerType(f32Type, spv::StorageClass::Input);
+  const FloatType *f32Type = spvContext.getFloatType(32);
+  const SpirvPointerType *ptrType =
+      spvContext.getPointerType(f32Type, spv::StorageClass::Input);
 
   if (inArraySize == 0) {
     // The input builtin does not have extra arrayness. Only need one index
@@ -397,25 +399,25 @@ uint32_t GlPerVertex::readClipCullArrayAsType(bool isClip, uint32_t offset,
     uint32_t count = {};
 
     if (isScalarType(asType)) {
-      const uint32_t offsetId = theBuilder.getConstantUint32(offset);
-      const uint32_t ptr =
-          theBuilder.createAccessChain(ptrType, clipCullVar, {offsetId});
-      return theBuilder.createLoad(f32Type, ptr);
+      auto *spirvConstant = spvContext.getConstantUint32(offset);
+      auto *ptr =
+          spvBuilder.createAccessChain(ptrType, clipCullVar, {spirvConstant});
+      return spvBuilder.createLoad(astContext.FloatTy, ptr);
     }
 
     if (isVectorType(asType, &elemType, &count)) {
       // The target SV_ClipDistance/SV_CullDistance variable is of vector
       // type, then we need to construct a vector out of float array elements.
-      llvm::SmallVector<uint32_t, 4> elements;
+      llvm::SmallVector<SpirvInstruction *, 4> elements;
       for (uint32_t i = 0; i < count; ++i) {
         // Read elements sequentially from the float array
-        const uint32_t offsetId = theBuilder.getConstantUint32(offset + i);
-        const uint32_t ptr =
-            theBuilder.createAccessChain(ptrType, clipCullVar, {offsetId});
-        elements.push_back(theBuilder.createLoad(f32Type, ptr));
+        auto *spirvConstant = spvContext.getConstantUint32(offset + i);
+        auto *ptr =
+            spvBuilder.createAccessChain(ptrType, clipCullVar, {spirvConstant});
+        elements.push_back(spvBuilder.createLoad(astContext.FloatTy, ptr));
       }
-      return theBuilder.createCompositeConstruct(
-          theBuilder.getVecType(f32Type, count), elements);
+      return spvBuilder.createCompositeConstruct(
+          spvContext.getVectorType(f32Type, count), elements);
     }
 
     llvm_unreachable("SV_ClipDistance/SV_CullDistance not float or vector of "
@@ -429,49 +431,50 @@ uint32_t GlPerVertex::readClipCullArrayAsType(bool isClip, uint32_t offset,
   // for indexing into the gl_PerVertex struct, and the third one for reading
   // the correct element in the float array for ClipDistance/CullDistance.
 
-  llvm::SmallVector<uint32_t, 8> arrayElements;
+  llvm::SmallVector<SpirvInstruction *, 8> arrayElements;
   QualType elemType = {};
   uint32_t count = {};
-  uint32_t arrayType = {};
-  uint32_t arraySize = theBuilder.getConstantUint32(inArraySize);
+  const ArrayType *arrayType = nullptr;
 
   if (isScalarType(asType)) {
-    arrayType = theBuilder.getArrayType(f32Type, arraySize);
+    arrayType = spvContext.getArrayType(f32Type, inArraySize);
     for (uint32_t i = 0; i < inArraySize; ++i) {
-      const uint32_t ptr = theBuilder.createAccessChain(
+      auto *ptr = spvBuilder.createAccessChain(
           ptrType, clipCullVar,
-          {theBuilder.getConstantUint32(i), // Block array index
-           theBuilder.getConstantUint32(offset)});
-      arrayElements.push_back(theBuilder.createLoad(f32Type, ptr));
+          {spvContext.getConstantUint32(i), // Block array index
+           spvContext.getConstantUint32(offset)});
+      arrayElements.push_back(spvBuilder.createLoad(astContext.FloatTy, ptr));
     }
   } else if (isVectorType(asType, &elemType, &count)) {
-    arrayType = theBuilder.getArrayType(theBuilder.getVecType(f32Type, count),
-                                        arraySize);
+    arrayType = spvContext.getArrayType(
+        spvContext.getVectorType(f32Type, count), inArraySize);
+
     for (uint32_t i = 0; i < inArraySize; ++i) {
       // For each gl_PerVertex block, we need to read a vector from it.
-      llvm::SmallVector<uint32_t, 4> vecElements;
+      llvm::SmallVector<SpirvInstruction *, 4> vecElements;
       for (uint32_t j = 0; j < count; ++j) {
-        const uint32_t ptr = theBuilder.createAccessChain(
+        auto *ptr = spvBuilder.createAccessChain(
             ptrType, clipCullVar,
             // Block array index
-            {theBuilder.getConstantUint32(i),
+            {spvContext.getConstantUint32(i),
              // Read elements sequentially from the float array
-             theBuilder.getConstantUint32(offset + j)});
-        vecElements.push_back(theBuilder.createLoad(f32Type, ptr));
+             spvContext.getConstantUint32(offset + j)});
+        vecElements.push_back(spvBuilder.createLoad(astContext.FloatTy, ptr));
       }
-      arrayElements.push_back(theBuilder.createCompositeConstruct(
-          theBuilder.getVecType(f32Type, count), vecElements));
+      arrayElements.push_back(spvBuilder.createCompositeConstruct(
+          spvContext.getVectorType(f32Type, count), vecElements));
     }
   } else {
     llvm_unreachable("SV_ClipDistance/SV_CullDistance not float or vector of "
                      "float case sneaked in");
   }
 
-  return theBuilder.createCompositeConstruct(arrayType, arrayElements);
+  return spvBuilder.createCompositeConstruct(arrayType, arrayElements);
 }
 
 bool GlPerVertex::readField(hlsl::Semantic::Kind semanticKind,
-                            uint32_t semanticIndex, uint32_t *value) {
+                            uint32_t semanticIndex, SpirvInstruction **value) {
+  assert(value);
   switch (semanticKind) {
   case hlsl::Semantic::Kind::ClipDistance: {
     const auto offsetIter = inClipOffset.find(semanticIndex);
@@ -501,15 +504,15 @@ bool GlPerVertex::readField(hlsl::Semantic::Kind semanticKind,
 }
 
 void GlPerVertex::writeClipCullArrayFromType(
-    llvm::Optional<uint32_t> invocationId, bool isClip, uint32_t offset,
-    QualType fromType, uint32_t fromValue) const {
-  const uint32_t clipCullVar = isClip ? outClipVar : outCullVar;
+    llvm::Optional<SpirvInstruction *> invocationId, bool isClip,
+    uint32_t offset, QualType fromType, SpirvInstruction *fromValue) const {
+  auto *clipCullVar = isClip ? outClipVar : outCullVar;
 
   // The ClipDistance/CullDistance is always an float array. We are accessing
   // it using pointers, which should be of pointer to float type.
-  const uint32_t f32Type = theBuilder.getFloat32Type();
-  const uint32_t ptrType =
-      theBuilder.getPointerType(f32Type, spv::StorageClass::Output);
+  const FloatType *f32Type = spvContext.getFloatType(32);
+  const SpirvPointerType *ptrType =
+      spvContext.getPointerType(f32Type, spv::StorageClass::Output);
 
   if (outArraySize == 0) {
     // The output builtin does not have extra arrayness. Only need one index
@@ -519,10 +522,10 @@ void GlPerVertex::writeClipCullArrayFromType(
     uint32_t count = {};
 
     if (isScalarType(fromType)) {
-      const uint32_t offsetId = theBuilder.getConstantUint32(offset);
-      const uint32_t ptr =
-          theBuilder.createAccessChain(ptrType, clipCullVar, {offsetId});
-      theBuilder.createStore(ptr, fromValue);
+      auto *constant = spvContext.getConstantUint32(offset);
+      auto *ptr =
+          spvBuilder.createAccessChain(ptrType, clipCullVar, {constant});
+      spvBuilder.createStore(ptr, fromValue);
       return;
     }
 
@@ -531,12 +534,12 @@ void GlPerVertex::writeClipCullArrayFromType(
       // type. We need to write each component in the vector out.
       for (uint32_t i = 0; i < count; ++i) {
         // Write elements sequentially into the float array
-        const uint32_t offsetId = theBuilder.getConstantUint32(offset + i);
-        const uint32_t ptr =
-            theBuilder.createAccessChain(ptrType, clipCullVar, {offsetId});
-        const uint32_t subValue =
-            theBuilder.createCompositeExtract(f32Type, fromValue, {i});
-        theBuilder.createStore(ptr, subValue);
+        auto *constant = spvContext.getConstantUint32(offset + i);
+        auto *ptr =
+            spvBuilder.createAccessChain(ptrType, clipCullVar, {constant});
+        auto *subValue = spvBuilder.createCompositeExtract(astContext.FloatTy,
+                                                           fromValue, {i});
+        spvBuilder.createStore(ptr, subValue);
       }
       return;
     }
@@ -558,31 +561,31 @@ void GlPerVertex::writeClipCullArrayFromType(
   // for indexing into the gl_PerVertex struct, and the third one for the
   // correct element in the float array for ClipDistance/CullDistance.
 
-  uint32_t arrayIndex = invocationId.getValue();
+  SpirvInstruction *arrayIndex = invocationId.getValue();
   QualType elemType = {};
   uint32_t count = {};
 
   if (isScalarType(fromType)) {
-    const uint32_t ptr =
-        theBuilder.createAccessChain(ptrType, clipCullVar,
-                                     {arrayIndex, // Block array index
-                                      theBuilder.getConstantUint32(offset)});
-    theBuilder.createStore(ptr, fromValue);
+    auto *ptr = spvBuilder.createAccessChain(
+        ptrType, clipCullVar,
+        {arrayIndex, spvContext.getConstantUint32(offset)});
+    spvBuilder.createStore(ptr, fromValue);
     return;
   }
 
   if (isVectorType(fromType, &elemType, &count)) {
     // For each gl_PerVertex block, we need to write a vector into it.
     for (uint32_t i = 0; i < count; ++i) {
-      const uint32_t ptr = theBuilder.createAccessChain(
+      auto *ptr = spvBuilder.createAccessChain(
           ptrType, clipCullVar,
           // Block array index
           {arrayIndex,
            // Write elements sequentially into the float array
-           theBuilder.getConstantUint32(offset + i)});
-      const uint32_t subValue =
-          theBuilder.createCompositeExtract(f32Type, fromValue, {i});
-      theBuilder.createStore(ptr, subValue);
+           spvContext.getConstantUint32(offset + i)});
+
+      auto *subValue =
+          spvBuilder.createCompositeExtract(astContext.FloatTy, fromValue, {i});
+      spvBuilder.createStore(ptr, subValue);
     }
     return;
   }
@@ -593,8 +596,8 @@ void GlPerVertex::writeClipCullArrayFromType(
 
 bool GlPerVertex::writeField(hlsl::Semantic::Kind semanticKind,
                              uint32_t semanticIndex,
-                             llvm::Optional<uint32_t> invocationId,
-                             uint32_t *value) {
+                             llvm::Optional<SpirvInstruction *> invocationId,
+                             SpirvInstruction **value) {
   // Similar to the writing logic in DeclResultIdMapper::createStageVars():
   //
   // Unlike reading, which may require us to read stand-alone builtins and

+ 29 - 26
tools/clang/lib/SPIRV/GlPerVertex.h

@@ -14,12 +14,11 @@
 #include "dxc/DXIL/DxilShaderModel.h"
 #include "dxc/DXIL/DxilSigPoint.h"
 #include "clang/SPIRV/ModuleBuilder.h"
+#include "clang/SPIRV/SpirvBuilder.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/SmallVector.h"
 
-#include "TypeTranslator.h"
-
 namespace clang {
 namespace spirv {
 
@@ -45,7 +44,7 @@ namespace spirv {
 class GlPerVertex {
 public:
   GlPerVertex(const hlsl::ShaderModel &sm, ASTContext &context,
-              ModuleBuilder &builder, TypeTranslator &translator);
+              SpirvContext &spvContext, SpirvBuilder &spvBuilder);
 
   /// Records a declaration of SV_ClipDistance/SV_CullDistance so later
   /// we can caculate the ClipDistance/CullDistance array layout.
@@ -64,10 +63,10 @@ public:
   /// and calculateClipCullDistanceArraySize().
   void generateVars(uint32_t inputArrayLength, uint32_t outputArrayLength);
 
-  /// Returns the <result-id>s for stage input variables.
-  llvm::SmallVector<uint32_t, 2> getStageInVars() const;
-  /// Returns the <result-id>s for stage output variables.
-  llvm::SmallVector<uint32_t, 2> getStageOutVars() const;
+  /// Returns the stage input variables.
+  llvm::SmallVector<SpirvVariable *, 2> getStageInVars() const;
+  /// Returns the stage output variables.
+  llvm::SmallVector<SpirvVariable *, 2> getStageOutVars() const;
 
   /// Requires the ClipDistance/CullDistance capability if we've seen
   /// definition of SV_ClipDistance/SV_CullDistance.
@@ -84,12 +83,13 @@ public:
   /// If invocation (should only be used for HS) is not llvm::None, only
   /// accesses the element at the invocation offset in the gl_PerVeterx array.
   ///
-  /// Emits SPIR-V instructions and returns true if we are accessing builtins
+  /// Creates SPIR-V instructions and returns true if we are accessing builtins
   /// that are ClipDistance or CullDistance. Does nothing and returns true if
   /// accessing builtins for others. Returns false if errors occurs.
   bool tryToAccess(hlsl::SigPoint::Kind sigPoint, hlsl::Semantic::Kind,
-                   uint32_t semanticIndex, llvm::Optional<uint32_t> invocation,
-                   uint32_t *value, bool noWriteBack);
+                   uint32_t semanticIndex,
+                   llvm::Optional<SpirvInstruction *> invocation,
+                   SpirvInstruction **value, bool noWriteBack);
 
 private:
   template <unsigned N>
@@ -100,28 +100,30 @@ private:
   }
 
   /// Creates a stand-alone ClipDistance/CullDistance builtin variable.
-  uint32_t createClipCullDistanceVar(bool asInput, bool isClip,
-                                     uint32_t arraySize);
+  SpirvVariable *createClipCullDistanceVar(bool asInput, bool isClip,
+                                           uint32_t arraySize);
 
-  /// Emits SPIR-V instructions for reading the data starting from offset in
+  /// Creates SPIR-V instructions for reading the data starting from offset in
   /// the ClipDistance/CullDistance builtin. The data read will be transformed
   /// into the given type asType.
-  uint32_t readClipCullArrayAsType(bool isClip, uint32_t offset,
-                                   QualType asType) const;
-  /// Emits SPIR-V instructions to read a field in gl_PerVertex.
+  SpirvInstruction *readClipCullArrayAsType(bool isClip, uint32_t offset,
+                                            QualType asType) const;
+  /// Creates SPIR-V instructions to read a field in gl_PerVertex.
   bool readField(hlsl::Semantic::Kind semanticKind, uint32_t semanticIndex,
-                 uint32_t *value);
+                 SpirvInstruction **value);
 
-  /// Emits SPIR-V instructions for writing data into the ClipDistance/
+  /// Creates SPIR-V instructions for writing data into the ClipDistance/
   /// CullDistance builtin starting from offset. The value to be written is
   /// fromValue, whose type is fromType. Necessary transformations will be
   /// generated to make sure type correctness.
-  void writeClipCullArrayFromType(llvm::Optional<uint32_t> invocationId,
-                                  bool isClip, uint32_t offset,
-                                  QualType fromType, uint32_t fromValue) const;
-  /// Emits SPIR-V instructions to write a field in gl_PerVertex.
+  void
+  writeClipCullArrayFromType(llvm::Optional<SpirvInstruction *> invocationId,
+                             bool isClip, uint32_t offset, QualType fromType,
+                             SpirvInstruction *fromValue) const;
+  /// Creates SPIR-V instructions to write a field in gl_PerVertex.
   bool writeField(hlsl::Semantic::Kind semanticKind, uint32_t semanticIndex,
-                  llvm::Optional<uint32_t> invocationId, uint32_t *value);
+                  llvm::Optional<SpirvInstruction *> invocationId,
+                  SpirvInstruction **value);
 
   /// Internal implementation for recordClipCullDistanceDecl().
   bool doGlPerVertexFacts(const DeclaratorDecl *decl, QualType type,
@@ -133,11 +135,12 @@ private:
 
   const hlsl::ShaderModel &shaderModel;
   ASTContext &astContext;
-  ModuleBuilder &theBuilder;
+  SpirvContext &spvContext;
+  SpirvBuilder &spvBuilder;
 
   /// Input/output ClipDistance/CullDistance variable.
-  uint32_t inClipVar, inCullVar;
-  uint32_t outClipVar, outCullVar;
+  SpirvVariable *inClipVar, *inCullVar;
+  SpirvVariable *outClipVar, *outCullVar;
 
   /// The array size for the input/output gl_PerVertex block member variables.
   /// HS input and output, DS input, GS input has an additional level of

+ 53 - 0
tools/clang/lib/SPIRV/SPIRVContext.cpp

@@ -249,6 +249,29 @@ SpirvContext::getStructType(llvm::ArrayRef<StructType::FieldInfo> fields,
   return structTypes.back();
 }
 
+const HybridStructType *SpirvContext::getHybridStructType(
+    llvm::ArrayRef<HybridStructType::FieldInfo> fields, llvm::StringRef name,
+    bool isReadOnly, HybridStructType::InterfaceType interfaceType) {
+  // We are creating a temporary struct type here for querying whether the
+  // same type was already created. It is a little bit costly, but we can
+  // avoid allocating directly from the bump pointer allocator, from which
+  // then we are unable to reclaim until the allocator itself is destroyed.
+
+  HybridStructType type(fields, name, isReadOnly, interfaceType);
+
+  auto found = std::find_if(
+      hybridStructTypes.begin(), hybridStructTypes.end(),
+      [&type](const HybridStructType *cachedType) { return type == *cachedType; });
+
+  if (found != hybridStructTypes.end())
+    return *found;
+
+  hybridStructTypes.push_back(
+      new (this) HybridStructType(fields, name, isReadOnly, interfaceType));
+
+  return hybridStructTypes.back();
+}
+
 const SpirvPointerType *SpirvContext::getPointerType(const SpirvType *pointee,
                                                      spv::StorageClass sc) {
   auto foundPointee = pointerTypes.find(pointee);
@@ -293,6 +316,17 @@ const StructType *SpirvContext::getByteAddressBufferType(bool isWritable) {
                        !isWritable);
 }
 
+const StructType *SpirvContext::getACSBufferCounterType() {
+  // Create int32.
+  const auto *int32Type = getSIntType(32);
+
+  // Create a struct containing the integer counter as its only member.
+  const StructType *type =
+      getStructType({int32Type}, "type.ACSBuffer.counter", {"counter"});
+
+  return type;
+}
+
 SpirvConstant *SpirvContext::getConstantUint32(uint32_t value) {
   const IntegerType *intType = getUIntType(32);
   SpirvConstantInteger tempConstant(intType, value);
@@ -312,5 +346,24 @@ SpirvConstant *SpirvContext::getConstantUint32(uint32_t value) {
   return intConst;
 }
 
+SpirvConstant *SpirvContext::getConstantInt32(int32_t value) {
+  const IntegerType *intType = getSIntType(32);
+  SpirvConstantInteger tempConstant(intType, value);
+
+  auto found =
+      std::find_if(integerConstants.begin(), integerConstants.end(),
+                   [&tempConstant](SpirvConstantInteger *cachedConstant) {
+                     return tempConstant == *cachedConstant;
+                   });
+
+  if (found != integerConstants.end())
+    return *found;
+
+  // Couldn't find the constant. Create one.
+  auto *intConst = new (this) SpirvConstantInteger(intType, value);
+  integerConstants.push_back(intConst);
+  return intConst;
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 54 - 39
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -602,12 +602,13 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci)
       entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
       shaderModel(*hlsl::ShaderModel::GetByName(
           ci.getCodeGenOpts().HLSLProfile.c_str())),
-      theContext(), featureManager(diags, spirvOptions),
+      theContext(), spvContext(astContext), featureManager(diags, spirvOptions),
       theBuilder(&theContext, &featureManager, spirvOptions),
+      spvBuilder(astContext, spvContext, &featureManager, spirvOptions),
       typeTranslator(astContext, theBuilder, diags, spirvOptions),
       declIdMapper(shaderModel, astContext, theBuilder, *this, typeTranslator,
                    featureManager, spirvOptions),
-      entryFunctionId(0), curFunction(nullptr), curThis(0),
+      entryFunction(nullptr), curFunction(nullptr), curThis(0),
       seenPushConstantAt(), isSpecConstantMode(false),
       foundNonUniformResourceIndex(false), needsLegalization(false),
       mainSourceFileId(0) {
@@ -699,6 +700,8 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
 
   theBuilder.addEntryPoint(getSpirvShaderStage(shaderModel), entryFunctionId,
                            entryFunctionName, declIdMapper.collectStageVars());
+  spvBuilder.addEntryPoint(getSpirvShaderStage(shaderModel), entryFunction,
+                           entryFunctionName, interfaces);
 
   // Add Location decorations to stage input/output variables.
   if (!declIdMapper.decorateStageIOLocations())
@@ -9329,17 +9332,17 @@ bool SPIRVEmitter::processGeometryShaderAttributes(const FunctionDecl *decl,
   bool success = true;
   assert(shaderModel.IsGS());
   if (auto *vcAttr = decl->getAttr<HLSLMaxVertexCountAttr>()) {
-    theBuilder.addExecutionMode(entryFunctionId,
-                                spv::ExecutionMode::OutputVertices,
-                                {static_cast<uint32_t>(vcAttr->getCount())});
+    spvBuilder.addExecutionMode(
+        entryFunction, spv::ExecutionMode::OutputVertices,
+        {static_cast<uint32_t>(vcAttr->getCount())}, decl->getLocation());
   }
 
   uint32_t invocations = 1;
   if (auto *instanceAttr = decl->getAttr<HLSLInstanceAttr>()) {
     invocations = static_cast<uint32_t>(instanceAttr->getCount());
   }
-  theBuilder.addExecutionMode(entryFunctionId, spv::ExecutionMode::Invocations,
-                              {invocations});
+  spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::Invocations,
+                              {invocations}, decl->getLocation());
 
   // Only one primitive type is permitted for the geometry shader.
   bool outPoint = false, outLine = false, outTriangle = false, inPoint = false,
@@ -9351,16 +9354,19 @@ bool SPIRVEmitter::processGeometryShaderAttributes(const FunctionDecl *decl,
     if (param->hasAttr<HLSLInOutAttr>()) {
       const auto paramType = param->getType();
       if (hlsl::IsHLSLTriangleStreamType(paramType) && !outTriangle) {
-        theBuilder.addExecutionMode(
-            entryFunctionId, spv::ExecutionMode::OutputTriangleStrip, {});
+        spvBuilder.addExecutionMode(entryFunction,
+                                    spv::ExecutionMode::OutputTriangleStrip, {},
+                                    param->getLocation());
         outTriangle = true;
       } else if (hlsl::IsHLSLLineStreamType(paramType) && !outLine) {
-        theBuilder.addExecutionMode(entryFunctionId,
-                                    spv::ExecutionMode::OutputLineStrip, {});
+        spvBuilder.addExecutionMode(entryFunction,
+                                    spv::ExecutionMode::OutputLineStrip, {},
+                                    param->getLocation());
         outLine = true;
       } else if (hlsl::IsHLSLPointStreamType(paramType) && !outPoint) {
-        theBuilder.addExecutionMode(entryFunctionId,
-                                    spv::ExecutionMode::OutputPoints, {});
+        spvBuilder.addExecutionMode(entryFunction,
+                                    spv::ExecutionMode::OutputPoints, {},
+                                    param->getLocation());
         outPoint = true;
       }
       // An output stream parameter will not have the input primitive type
@@ -9371,28 +9377,31 @@ bool SPIRVEmitter::processGeometryShaderAttributes(const FunctionDecl *decl,
     // Add an execution mode based on the input primitive type. Do not add an
     // execution mode more than once.
     if (param->hasAttr<HLSLPointAttr>() && !inPoint) {
-      theBuilder.addExecutionMode(entryFunctionId,
-                                  spv::ExecutionMode::InputPoints, {});
+      spvBuilder.addExecutionMode(entryFunction,
+                                  spv::ExecutionMode::InputPoints, {},
+                                  param->getLocation());
       *arraySize = 1;
       inPoint = true;
     } else if (param->hasAttr<HLSLLineAttr>() && !inLine) {
-      theBuilder.addExecutionMode(entryFunctionId,
-                                  spv::ExecutionMode::InputLines, {});
+      spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::InputLines,
+                                  {}, param->getLocation());
       *arraySize = 2;
       inLine = true;
     } else if (param->hasAttr<HLSLTriangleAttr>() && !inTriangle) {
-      theBuilder.addExecutionMode(entryFunctionId,
-                                  spv::ExecutionMode::Triangles, {});
+      spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::Triangles,
+                                  {}, param->getLocation());
       *arraySize = 3;
       inTriangle = true;
     } else if (param->hasAttr<HLSLLineAdjAttr>() && !inLineAdj) {
-      theBuilder.addExecutionMode(entryFunctionId,
-                                  spv::ExecutionMode::InputLinesAdjacency, {});
+      spvBuilder.addExecutionMode(entryFunction,
+                                  spv::ExecutionMode::InputLinesAdjacency, {},
+                                  param->getLocation());
       *arraySize = 4;
       inLineAdj = true;
     } else if (param->hasAttr<HLSLTriangleAdjAttr>() && !inTriangleAdj) {
-      theBuilder.addExecutionMode(
-          entryFunctionId, spv::ExecutionMode::InputTrianglesAdjacency, {});
+      spvBuilder.addExecutionMode(entryFunction,
+                                  spv::ExecutionMode::InputTrianglesAdjacency,
+                                  {}, param->getLocation());
       *arraySize = 6;
       inTriangleAdj = true;
     }
@@ -9414,18 +9423,21 @@ bool SPIRVEmitter::processGeometryShaderAttributes(const FunctionDecl *decl,
 }
 
 void SPIRVEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
-  theBuilder.addExecutionMode(entryFunctionId,
-                              spv::ExecutionMode::OriginUpperLeft, {});
+  spvBuilder.addExecutionMode(entryFunction,
+                              spv::ExecutionMode::OriginUpperLeft, {},
+                              decl->getLocation());
   if (decl->getAttr<HLSLEarlyDepthStencilAttr>()) {
-    theBuilder.addExecutionMode(entryFunctionId,
-                                spv::ExecutionMode::EarlyFragmentTests, {});
+    spvBuilder.addExecutionMode(entryFunction,
+                                spv::ExecutionMode::EarlyFragmentTests, {},
+                                decl->getLocation());
   }
   if (decl->getAttr<VKPostDepthCoverageAttr>()) {
     theBuilder.addExtension(Extension::KHR_post_depth_coverage,
                             "[[vk::post_depth_coverage]]", decl->getLocation());
     theBuilder.requireCapability(spv::Capability::SampleMaskPostDepthCoverage);
-    theBuilder.addExecutionMode(entryFunctionId,
-                                spv::ExecutionMode::PostDepthCoverage, {});
+    spvBuilder.addExecutionMode(entryFunction,
+                                spv::ExecutionMode::PostDepthCoverage, {},
+                                decl->getLocation());
   }
 }
 
@@ -9439,8 +9451,8 @@ void SPIRVEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
     z = static_cast<uint32_t>(numThreadsAttr->getZ());
   }
 
-  theBuilder.addExecutionMode(entryFunctionId, spv::ExecutionMode::LocalSize,
-                              {x, y, z});
+  spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
+                              {x, y, z}, decl->getLocation());
 }
 
 bool SPIRVEmitter::processTessellationShaderAttributes(
@@ -9461,7 +9473,8 @@ bool SPIRVEmitter::processTessellationShaderAttributes(
                 domain->getLocation());
       return false;
     }
-    theBuilder.addExecutionMode(entryFunctionId, hsExecMode, {});
+    spvBuilder.addExecutionMode(entryFunction, hsExecMode, {},
+                                decl->getLocation());
   }
 
   // Early return for domain shaders as domain shaders only takes the 'domain'
@@ -9488,7 +9501,8 @@ bool SPIRVEmitter::processTessellationShaderAttributes(
                 partitioning->getLocation());
       return false;
     }
-    theBuilder.addExecutionMode(entryFunctionId, hsExecMode, {});
+    spvBuilder.addExecutionMode(entryFunction, hsExecMode, {},
+                                decl->getLocation());
   }
   if (auto *outputTopology = decl->getAttr<HLSLOutputTopologyAttr>()) {
     const auto topology = outputTopology->getTopology().lower();
@@ -9502,7 +9516,8 @@ bool SPIRVEmitter::processTessellationShaderAttributes(
     // default?
     if (topology != "line") {
       if (hsExecMode != spv::ExecutionMode::Max) {
-        theBuilder.addExecutionMode(entryFunctionId, hsExecMode, {});
+        spvBuilder.addExecutionMode(entryFunction, hsExecMode, {},
+                                    decl->getLocation());
       } else {
         emitError("unknown output topology in hull shader",
                   outputTopology->getLocation());
@@ -9512,9 +9527,9 @@ bool SPIRVEmitter::processTessellationShaderAttributes(
   }
   if (auto *controlPoints = decl->getAttr<HLSLOutputControlPointsAttr>()) {
     *numOutputControlPoints = controlPoints->getCount();
-    theBuilder.addExecutionMode(entryFunctionId,
+    spvBuilder.addExecutionMode(entryFunction,
                                 spv::ExecutionMode::OutputVertices,
-                                {*numOutputControlPoints});
+                                {*numOutputControlPoints}, decl->getLocation());
   }
   if (auto *pcf = decl->getAttr<HLSLPatchConstantFuncAttr>()) {
     llvm::StringRef pcf_name = pcf->getFunctionName();
@@ -9549,10 +9564,10 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   // The wrapper entry function surely does not have pre-assigned <result-id>
   // for it like other functions that got added to the work queue following
   // function calls. And the wrapper is the entry function.
-  entryFunctionId =
-      theBuilder.beginFunction(funcType, voidType, decl->getName());
+  entryFunction = spvBuilder.beginFunction(
+      astContext.VoidTy, /*SourceLocation*/ {}, decl->getName());
   // Note this should happen before using declIdMapper for other tasks.
-  declIdMapper.setEntryFunctionId(entryFunctionId);
+  declIdMapper.setEntryFunction(entryFunction);
 
   // Handle attributes specific to each shader stage
   if (shaderModel.IsPS()) {

+ 5 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -29,6 +29,8 @@
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/SPIRV/FeatureManager.h"
 #include "clang/SPIRV/ModuleBuilder.h"
+#include "clang/SPIRV/SpirvBuilder.h"
+#include "clang/SPIRV/SpirvContext.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 
@@ -932,8 +934,10 @@ private:
   const hlsl::ShaderModel &shaderModel;
 
   SPIRVContext theContext;
+  SpirvContext spvContext;
   FeatureManager featureManager;
   ModuleBuilder theBuilder;
+  SpirvBuilder spvBuilder;
   TypeTranslator typeTranslator;
   DeclResultIdMapper declIdMapper;
 
@@ -945,7 +949,7 @@ private:
 
   /// <result-id> for the entry function. Initially it is zero and will be reset
   /// when starting to translate the entry function.
-  uint32_t entryFunctionId;
+  SpirvFunction *entryFunction;
   /// The current function under traversal.
   const FunctionDecl *curFunction;
   /// The SPIR-V function parameter for the current this object.

+ 67 - 6
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -29,6 +29,17 @@ SpirvFunction *SpirvBuilder::beginFunction(QualType returnType,
   return function;
 }
 
+SpirvFunction *SpirvBuilder::createFunction(QualType returnType,
+                                            SourceLocation loc,
+                                            llvm::StringRef funcName,
+                                            bool isAlias) {
+  function = new (context) SpirvFunction(
+      returnType, /*id*/ 0, spv::FunctionControlMask::MaskNone, loc, funcName);
+  function->setConstainsAliasComponent(isAlias);
+  module->addFunction(function);
+  return function;
+}
+
 SpirvFunctionParameter *SpirvBuilder::addFnParam(QualType ptrType,
                                                  SourceLocation loc,
                                                  llvm::StringRef name) {
@@ -98,6 +109,17 @@ SpirvComposite *SpirvBuilder::createCompositeConstruct(
   return instruction;
 }
 
+SpirvComposite *SpirvBuilder::createCompositeConstruct(
+    const SpirvType *resultType,
+    llvm::ArrayRef<SpirvInstruction *> constituents, SourceLocation loc) {
+  assert(insertPoint && "null insert point");
+  auto *instruction = new (context)
+      SpirvComposite(/*QualType*/ {}, /*id*/ 0, loc, constituents);
+  instruction->setResultType(resultType);
+  insertPoint->addInstruction(instruction);
+  return instruction;
+}
+
 SpirvCompositeExtract *SpirvBuilder::createCompositeExtract(
     QualType resultType, SpirvInstruction *composite,
     llvm::ArrayRef<uint32_t> indexes, SourceLocation loc) {
@@ -139,6 +161,17 @@ SpirvLoad *SpirvBuilder::createLoad(QualType resultType,
   return instruction;
 }
 
+SpirvLoad *SpirvBuilder::createLoad(const SpirvType *resultType,
+                                    SpirvInstruction *pointer,
+                                    SourceLocation loc) {
+  assert(insertPoint && "null insert point");
+  auto *instruction =
+      new (context) SpirvLoad(/*QualType*/ {}, /*id*/ 0, loc, pointer);
+  instruction->setResultType(resultType);
+  insertPoint->addInstruction(instruction);
+  return instruction;
+}
+
 void SpirvBuilder::createStore(SpirvInstruction *address,
                                SpirvInstruction *value, SourceLocation loc) {
   assert(insertPoint && "null insert point");
@@ -168,6 +201,17 @@ SpirvBuilder::createAccessChain(QualType resultType, SpirvInstruction *base,
   return instruction;
 }
 
+SpirvAccessChain *SpirvBuilder::createAccessChain(
+    const SpirvType *resultType, SpirvInstruction *base,
+    llvm::ArrayRef<SpirvInstruction *> indexes, SourceLocation loc) {
+  assert(insertPoint && "null insert point");
+  auto *instruction = new (context)
+      SpirvAccessChain(/*QualType*/ {}, /*id*/ 0, loc, base, indexes);
+  instruction->setResultType(resultType);
+  insertPoint->addInstruction(instruction);
+  return instruction;
+}
+
 SpirvUnaryOp *SpirvBuilder::createUnaryOp(spv::Op op, QualType resultType,
                                           SpirvInstruction *operand,
                                           SourceLocation loc) {
@@ -696,12 +740,15 @@ SpirvVariable *SpirvBuilder::addStageIOVar(QualType type,
   return var;
 }
 
-SpirvVariable *SpirvBuilder::addStageBuiltinVar(QualType type,
+SpirvVariable *SpirvBuilder::addStageBuiltinVar(const SpirvType *type,
                                                 spv::StorageClass storageClass,
                                                 spv::BuiltIn builtin,
                                                 SourceLocation loc) {
   // Note: We store the underlying type in the variable, *not* the pointer type.
-  auto *var = new (context) SpirvVariable(type, /*id*/ 0, loc, storageClass);
+  // TODO(ehsan): type pointer should be added in lowering the type.
+  auto *var =
+      new (context) SpirvVariable(/*QualType*/ {}, /*id*/ 0, loc, storageClass);
+  var->setResultType(type);
   module->addVariable(var);
 
   // Decorate with the specified Builtin
@@ -725,6 +772,20 @@ SpirvVariable *SpirvBuilder::addModuleVar(
   return var;
 }
 
+SpirvVariable *SpirvBuilder::addModuleVar(
+    const SpirvType *type, spv::StorageClass storageClass, llvm::StringRef name,
+    llvm::Optional<SpirvInstruction *> init, SourceLocation loc) {
+  assert(storageClass != spv::StorageClass::Function);
+  // Note: We store the underlying type in the variable, *not* the pointer type.
+  auto *var =
+      new (context) SpirvVariable(/*QualType*/ {}, /*id*/ 0, loc, storageClass,
+                                  init.hasValue() ? init.getValue() : nullptr);
+  var->setResultType(type);
+  var->setDebugName(name);
+  module->addVariable(var);
+  return var;
+}
+
 void SpirvBuilder::decorateLocation(SpirvInstruction *target, uint32_t location,
                                     SourceLocation srcLoc) {
   auto *decor = new (context)
@@ -767,15 +828,15 @@ void SpirvBuilder::decorateInputAttachmentIndex(SpirvInstruction *target,
   module->addDecoration(decor);
 }
 
-void SpirvBuilder::decorateCounterBufferId(SpirvInstruction *mainBuffer,
-                                           uint32_t counterBufferId,
-                                           SourceLocation srcLoc) {
+void SpirvBuilder::decorateCounterBuffer(SpirvInstruction *mainBuffer,
+                                         SpirvInstruction *counterBuffer,
+                                         SourceLocation srcLoc) {
   if (spirvOptions.enableReflect) {
     addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
                  srcLoc);
     auto *decor = new (context) SpirvDecoration(
         srcLoc, mainBuffer, spv::Decoration::HlslCounterBufferGOOGLE,
-        {counterBufferId});
+        {counterBuffer});
     module->addDecoration(decor);
   }
 }

+ 19 - 6
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -86,7 +86,10 @@ SpirvInstruction::SpirvInstruction(Kind k, spv::Op op, QualType astType,
                                    uint32_t id, SourceLocation loc)
     : kind(k), opcode(op), astResultType(astType), resultId(id), srcLoc(loc),
       debugName(), resultType(nullptr), resultTypeId(0),
-      layoutRule(SpirvLayoutRule::Void) {}
+      layoutRule(SpirvLayoutRule::Void), containsAlias(false),
+      storageClass(spv::StorageClass::Max), isRValue_(false),
+      isConstant_(false), isSpecConstant_(false), isRelaxedPrecision_(false),
+      isNonUniform_(false) {}
 
 SpirvCapability::SpirvCapability(SourceLocation loc, spv::Capability cap)
     : SpirvInstruction(IK_Capability, spv::Op::OpCapability, QualType(),
@@ -122,8 +125,7 @@ SpirvEntryPoint::SpirvEntryPoint(SourceLocation loc,
       interfaceVec(iface.begin(), iface.end()) {}
 
 // OpExecutionMode and OpExecutionModeId instructions
-SpirvExecutionMode::SpirvExecutionMode(SourceLocation loc,
-                                       SpirvEntryPoint *entry,
+SpirvExecutionMode::SpirvExecutionMode(SourceLocation loc, SpirvFunction *entry,
                                        spv::ExecutionMode em,
                                        llvm::ArrayRef<uint32_t> paramsVec,
                                        bool usesIdParams)
@@ -163,7 +165,7 @@ SpirvDecoration::SpirvDecoration(SourceLocation loc,
                                       : spv::Op::OpDecorate,
                        /*type*/ {}, /*id*/ 0, loc),
       target(targetInst), decoration(decor), index(idx),
-      params(p.begin(), p.end()) {}
+      params(p.begin(), p.end()), idParams() {}
 
 SpirvDecoration::SpirvDecoration(SourceLocation loc,
                                  SpirvInstruction *targetInst,
@@ -174,17 +176,28 @@ SpirvDecoration::SpirvDecoration(SourceLocation loc,
                        idx.hasValue() ? spv::Op::OpMemberDecorate
                                       : spv::Op::OpDecorate,
                        /*type*/ {}, /*id*/ 0, loc),
-      target(targetInst), decoration(decor), index(idx), params() {
+      target(targetInst), decoration(decor), index(idx), params(), idParams() {
   const auto &stringWords = string::encodeSPIRVString(strParam);
   params.insert(params.end(), stringWords.begin(), stringWords.end());
 }
 
+SpirvDecoration::SpirvDecoration(SourceLocation loc,
+                                 SpirvInstruction *targetInst,
+                                 spv::Decoration decor,
+                                 llvm::ArrayRef<SpirvInstruction *> ids)
+    : SpirvInstruction(IK_Decoration, spv::Op::OpDecorateId,
+                       /*type*/ {}, /*id*/ 0, loc),
+      target(targetInst), decoration(decor), index(llvm::None), params(),
+      idParams(ids.begin(), ids.end()) {}
+
 SpirvVariable::SpirvVariable(QualType resultType, uint32_t resultId,
                              SourceLocation loc, spv::StorageClass sc,
                              SpirvInstruction *initializerInst)
     : SpirvInstruction(IK_Variable, spv::Op::OpVariable, resultType, resultId,
                        loc),
-      storageClass(sc), initializer(initializerInst) {}
+      initializer(initializerInst) {
+  setStorageClass(sc);
+}
 
 SpirvFunctionParameter::SpirvFunctionParameter(QualType resultType,
                                                uint32_t resultId,

+ 4 - 3
tools/clang/lib/SPIRV/SpirvModule.cpp

@@ -14,7 +14,8 @@ namespace clang {
 namespace spirv {
 
 SpirvModule::SpirvModule()
-    : bound(1), shaderModelVersion(0), memoryModel(nullptr), debugSource(nullptr) {}
+    : bound(1), shaderModelVersion(0), memoryModel(nullptr),
+      debugSource(nullptr) {}
 
 bool SpirvModule::invokeVisitor(Visitor *visitor) {
   if (!visitor->visit(this, Visitor::Phase::Init))
@@ -69,7 +70,7 @@ bool SpirvModule::invokeVisitor(Visitor *visitor) {
 
 void SpirvModule::addFunction(SpirvFunction *fn) {
   assert(fn && "cannot add null function to the module");
-  functions.push_back(fn);
+  functions.insert(fn);
 }
 
 void SpirvModule::addCapability(SpirvCapability *cap) {
@@ -82,7 +83,7 @@ void SpirvModule::setMemoryModel(SpirvMemoryModel *model) {
   memoryModel = model;
 }
 
-void SpirvModule::addEntryPoint(SpirvEntryPoint* ep) {
+void SpirvModule::addEntryPoint(SpirvEntryPoint *ep) {
   assert(ep && "cannot add null as an entry point");
   entryPoints.push_back(ep);
 }

+ 18 - 0
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -68,5 +68,23 @@ bool StructType::operator==(const StructType &that) const {
          readOnly == that.readOnly;
 }
 
+HybridStructType::HybridStructType(
+    llvm::ArrayRef<HybridStructType::FieldInfo> fieldsVec, llvm::StringRef name,
+    bool isReadOnly, HybridStructType::InterfaceType iface)
+    : SpirvType(TK_HybridStruct), fields(fieldsVec.begin(), fieldsVec.end()),
+      structName(name), readOnly(isReadOnly), interfaceType(iface) {}
+
+bool HybridStructType::FieldInfo::
+operator==(const HybridStructType::FieldInfo &that) const {
+  return astType == that.astType && spirvType == that.spirvType &&
+         name == that.name && vkOffsetAttr == that.vkOffsetAttr &&
+         packOffsetAttr == that.packOffsetAttr;
+}
+
+bool HybridStructType::operator==(const HybridStructType &that) const {
+  return fields == that.fields && structName == that.structName &&
+         readOnly == that.readOnly;
+}
+
 } // namespace spirv
 } // namespace clang