Procházet zdrojové kódy

[spirv] Add EmitTypeHandler to emit types.

Ehsan před 6 roky
rodič
revize
6c31214960

+ 77 - 5
tools/clang/include/clang/SPIRV/EmitVisitor.h

@@ -8,20 +8,84 @@
 #ifndef LLVM_CLANG_SPIRV_EMITVISITOR_H
 #define LLVM_CLANG_SPIRV_EMITVISITOR_H
 
+#include "clang/SPIRV/SPIRVContext.h"
 #include "clang/SPIRV/SpirvVisitor.h"
+#include "llvm/ADT/DenseMap.h"
+
+#include <functional>
 
 namespace clang {
 namespace spirv {
 
-class SpirvModule;
 class SpirvFunction;
 class SpirvBasicBlock;
+class SpirvType;
+
+// Provides DenseMapInfo for SpirvLayoutRule so that we can use it as key to
+// DenseMap.
+//
+// Mostly from DenseMapInfo<unsigned> in DenseMapInfo.h.
+struct SpirvLayoutRuleDenseMapInfo {
+  static inline SpirvLayoutRule getEmptyKey() { return SpirvLayoutRule::Void; }
+  static inline SpirvLayoutRule getTombstoneKey() {
+    return SpirvLayoutRule::Void;
+  }
+  static unsigned getHashValue(const SpirvLayoutRule &Val) {
+    return static_cast<unsigned>(Val) * 37U;
+  }
+  static bool isEqual(const SpirvLayoutRule &LHS, const SpirvLayoutRule &RHS) {
+    return LHS == RHS;
+  }
+};
+
+class EmitTypeHandler {
+public:
+  EmitTypeHandler(SpirvContext &c, std::vector<uint32_t> *decVec,
+                  std::vector<uint32_t> *typesVec,
+                  const std::function<uint32_t()> &takeNextIdFn)
+      : context(c), annotationsBinary(decVec), typeConstantBinary(typesVec),
+        takeNextIdFunction(takeNextIdFn) {
+    assert(decVec);
+    assert(typesVec);
+  }
+
+  // Disable copy constructor/assignment.
+  EmitTypeHandler(const EmitTypeHandler &) = delete;
+  EmitTypeHandler &operator=(const EmitTypeHandler &) = delete;
+
+  // Emits the instruction for the given type into the typeConstantBinary and
+  // returns the result-id for the type.
+  uint32_t emitType(const SpirvType *, SpirvLayoutRule);
+
+  uint32_t getResultIdForType(const SpirvType *, SpirvLayoutRule,
+                              bool *alreadyExists);
+
+private:
+  void initTypeInstruction(spv::Op op);
+  void finalizeTypeInstruction();
+
+private:
+  SpirvContext &context;
+  std::vector<uint32_t> curTypeInst;
+  std::vector<uint32_t> *annotationsBinary;
+  std::vector<uint32_t> *typeConstantBinary;
+  std::function<uint32_t()> takeNextIdFunction;
+
+  // emittedTypes is a map that caches the <result-id> of types in order to
+  // avoid translating a type multiple times.
+  using LayoutRuleToTypeIdMap =
+      llvm::DenseMap<SpirvLayoutRule, uint32_t, SpirvLayoutRuleDenseMapInfo>;
+  llvm::DenseMap<const SpirvType *, LayoutRuleToTypeIdMap> emittedTypes;
+};
 
 /// \breif The visitor class that emits the SPIR-V words from the in-memory
 /// representation.
 class EmitVisitor : public Visitor {
 public:
-  EmitVisitor() = default;
+  EmitVisitor(const SpirvCodeGenOptions &opts, SpirvContext &ctx)
+      : Visitor(opts, ctx), id(0),
+        typeHandler(ctx, &annotationsBinary, &typeConstantBinary,
+                    [this]() -> uint32_t { return takeNextId(); }) {}
 
   // Visit different SPIR-V constructs for emitting.
   bool visit(SpirvModule *, Phase phase);
@@ -84,8 +148,16 @@ public:
   bool visit(SpirvVectorShuffle *);
 
 private:
+  // Returns the next available result-id.
+  uint32_t takeNextId() { return ++id; }
+
   // Initiates the creation of a new instruction with the given Opcode.
   void initInstruction(spv::Op);
+  // Initiates the creation of the given SPIR-V instruction.
+  // If the given instruction has a return type, it will also trigger emitting
+  // the necessary type (and its associated decorations) and uses its result-id
+  // in the instruction.
+  void initInstruction(SpirvInstruction *);
 
   // Finalizes the current instruction by encoding the instruction size into the
   // first word, and then appends the current instruction to the SPIR-V binary.
@@ -94,9 +166,6 @@ private:
   // Encodes the given string into the current instruction that is being built.
   void encodeString(llvm::StringRef value);
 
-  // Provides the next available <result-id>
-  uint32_t getNextId() { return ++id; }
-
   // Emits an OpName instruction into the debugBinary for the given target.
   void emitDebugNameForInstruction(uint32_t resultId, llvm::StringRef name);
 
@@ -104,7 +173,10 @@ private:
   // using the type information.
 
 private:
+  // The last result-id that's been used so far.
   uint32_t id;
+  // Handler for emitting types and their related instructions.
+  EmitTypeHandler typeHandler;
   // Current instruction being built
   SmallVector<uint32_t, 16> curInst;
   // All preamble instructions in the following order:

+ 16 - 3
tools/clang/include/clang/SPIRV/SPIRVContext.h

@@ -15,6 +15,7 @@
 #include "clang/Frontend/FrontendAction.h"
 #include "clang/SPIRV/Constant.h"
 #include "clang/SPIRV/Decoration.h"
+#include "clang/SPIRV/SpirvInstruction.h"
 #include "clang/SPIRV/SpirvType.h"
 #include "clang/SPIRV/Type.h"
 #include "llvm/ADT/DenseMap.h"
@@ -145,7 +146,7 @@ struct StorageClassDenseMapInfo {
 /// the SPIR-V entities allocated in memory.
 class SpirvContext {
 public:
-  SpirvContext();
+  SpirvContext(const ASTContext &ctx);
   ~SpirvContext() = default;
 
   // Forbid copy construction and assignment
@@ -177,8 +178,9 @@ public:
   const MatrixType *getMatrixType(const SpirvType *vecType, uint32_t vecCount,
                                   bool isRowMajor);
 
-  const ImageType *getImageType(const SpirvType *, spv::Dim, bool arrayed,
-                                bool ms, ImageType::WithSampler sampled,
+  const ImageType *getImageType(const SpirvType *, spv::Dim,
+                                ImageType::WithDepth, bool arrayed, bool ms,
+                                ImageType::WithSampler sampled,
                                 spv::ImageFormat);
   const SamplerType *getSamplerType() const { return samplerType; }
   const SampledImageType *getSampledImageType(const ImageType *image);
@@ -198,7 +200,12 @@ public:
 
   const StructType *getByteAddressBufferType(bool isWritable);
 
+  SpirvConstant *getConstantUint32(uint32_t value, SourceLocation loc = {});
+  // TODO: Add getConstant* methods for other types.
+
 private:
+  const ASTContext &astContext;
+
   /// \brief The allocator used to create SPIR-V entity objects.
   ///
   /// SPIR-V entity objects are never destructed; rather, all memory associated
@@ -245,6 +252,12 @@ private:
   llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
 
   llvm::SmallVector<const FunctionType *, 8> functionTypes;
+
+  // Unique constants
+  // Avoid premature optimiztion: we do a linear search to find an existing
+  // constant (if any). This can be done faster if we use maps or use different
+  // vectors based on the constant type.
+  llvm::SmallVector<SpirvConstant *, 8> constants;
 };
 
 } // end namespace spirv

+ 27 - 4
tools/clang/include/clang/SPIRV/SpirvFunction.h

@@ -42,13 +42,29 @@ public:
   // TODO: The responsibility of assigning the result-id of a function shouldn't
   // be on the function itself.
   uint32_t getResultId() const { return functionId; }
+
   // TODO: There should be a pass for lowering QualType to SPIR-V type,
   // and this method should be able to return the result-id of the SPIR-V type.
   // Both the return type of the function as well as the SPIR-V "function type"
   // are needed. SPIR-V function type (obtained by OpFunctionType) includes both
   // the return type as well as argument types.
-  uint32_t getReturnTypeId() const { return 0; }
-  uint32_t getFunctionTypeId() const { return 0; }
+  uint32_t getReturnTypeId() const { return returnTypeId; }
+  void setReturnTypeId(uint32_t id) { returnTypeId = id; }
+
+  // Sets the lowered (SPIR-V) function type.
+  void setReturnType(SpirvType *type) { returnType = type; }
+  // Returns the lowered (SPIR-V) function type.
+  const SpirvType *getReturnType() const { return returnType; }
+
+  // Sets the SPIR-V type of the function
+  void setFunctionType(FunctionType *type) { fnType = type; }
+  // Returns the SPIR-V type of the function
+  FunctionType *getFunctionType() const { return fnType; }
+
+  // Sets the result-id of the OpTypeFunction
+  void setFunctionTypeId(uint32_t id) { fnTypeId = id; }
+  // Returns the result-id of the OpTypeFunction
+  uint32_t getFunctionTypeId() const { return fnTypeId; }
 
   void setFunctionName(llvm::StringRef name) { functionName = name; }
   llvm::StringRef getFunctionName() const { return functionName; }
@@ -58,8 +74,15 @@ public:
   void addBasicBlock(SpirvBasicBlock *);
 
 private:
-  QualType functionType;                    ///< This function's type
-  uint32_t functionId;                      ///< This function's <result-id>
+  uint32_t functionId; ///< This function's <result-id>
+
+  QualType astReturnType; ///< The return type
+  SpirvType *returnType;  ///< The lowered return type
+  uint32_t returnTypeId;  ///< result-id for the return type
+
+  FunctionType *fnType; ///< The SPIR-V function type
+  uint32_t fnTypeId;    ///< result-id for the SPIR-V function type
+
   spv::FunctionControlMask functionControl; ///< SPIR-V function control
   SourceLocation functionLoc;               ///< Location in source code
   std::string functionName;                 ///< This function's name

+ 19 - 11
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -24,6 +24,7 @@ class Visitor;
 class SpirvBasicBlock;
 class SpirvFunction;
 class SpirvVariable;
+class SpirvType;
 
 /// \brief The base class for representing SPIR-V instructions.
 class SpirvInstruction {
@@ -116,23 +117,27 @@ public:
 
   Kind getKind() const { return kind; }
   spv::Op getopcode() const { return opcode; }
-  QualType getResultType() const { return resultType; }
+  QualType getAstResultType() const { return astResultType; }
 
-  // TODO: The QualType should be lowered to a SPIR-V type and the result-id of
-  // the SPIR-V type should be stored somewhere (either in SpirvInstruction or
-  // in a map in SpirvModule). The id of the result type should be retreived and
-  // returned by this method.
-  uint32_t getResultTypeId() const { return 0; }
+  uint32_t getResultTypeId() const { return resultTypeId; }
+  void setResultTypeId(uint32_t id) { resultTypeId = id; }
+
+  bool hasResultType() const { return resultType != nullptr; }
+  SpirvType *getResultType() const { return resultType; }
 
   // TODO: The responsibility of assigning the result-id of an instruction
   // shouldn't be on the instruction itself.
   uint32_t getResultId() const { return resultId; }
+  void setResultId(uint32_t id) { resultId = id; }
 
   clang::SourceLocation getSourceLocation() const { return srcLoc; }
 
   void setDebugName(llvm::StringRef name) { debugName = name; }
   llvm::StringRef getDebugName() const { return debugName; }
 
+  SpirvLayoutRule getLayoutRule() const { return layoutRule; }
+  void setLayoutRule(SpirvLayoutRule rule) { layoutRule = rule; }
+
 protected:
   // Forbid creating SpirvInstruction directly
   SpirvInstruction(Kind kind, spv::Op opcode, QualType resultType,
@@ -142,10 +147,13 @@ private:
   const Kind kind;
 
   spv::Op opcode;
-  QualType resultType;
+  QualType astResultType;
   uint32_t resultId;
   SourceLocation srcLoc;
   std::string debugName;
+  SpirvType *resultType;
+  uint32_t resultTypeId;
+  SpirvLayoutRule layoutRule;
 };
 
 #define DECLARE_INVOKE_VISITOR_FOR_CLASS(cls)                                  \
@@ -515,7 +523,7 @@ public:
 
   // Returns all possible basic blocks that could be taken by the branching
   // instruction.
-  llvm::ArrayRef<SpirvBasicBlock *> getTargetBranches() const {
+  llvm::ArrayRef<SpirvBasicBlock *> getTargetBranches() const override {
     return {targetLabel};
   }
 
@@ -537,7 +545,7 @@ public:
 
   DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvBranchConditional)
 
-  llvm::ArrayRef<SpirvBasicBlock *> getTargetBranches() const {
+  llvm::ArrayRef<SpirvBasicBlock *> getTargetBranches() const override {
     return {trueLabel, falseLabel};
   }
 
@@ -606,7 +614,7 @@ public:
   // Returns the branch label that will be taken for the given literal.
   SpirvBasicBlock *getTargetLabelForLiteral(uint32_t) const;
   // Returns all possible branches that could be taken by the switch statement.
-  llvm::ArrayRef<SpirvBasicBlock *> getTargetBranches() const;
+  llvm::ArrayRef<SpirvBasicBlock *> getTargetBranches() const override;
 
 private:
   SpirvInstruction *selector;
@@ -956,7 +964,7 @@ public:
 
   uint32_t getBitwidth() const { return bitwidth; }
   void setBitwidth(uint32_t width) { bitwidth = width; }
-  bool isSigned() const { return getResultType()->isSignedIntegerType(); }
+  bool isSigned() const { return getAstResultType()->isSignedIntegerType(); }
 
 private:
   uint32_t bitwidth;

+ 3 - 2
tools/clang/include/clang/SPIRV/SpirvModule.h

@@ -69,7 +69,7 @@ public:
   void addExecutionMode(SpirvExecutionMode *);
 
   // Adds an extension to the module.
-  void addExtension(SpirvExtension*);
+  void addExtension(SpirvExtension *);
 
   // Adds an extended instruction set to the module.
   void addExtInstSet(SpirvExtInstImport *);
@@ -87,9 +87,10 @@ public:
   void setShaderModelVersion(uint32_t v) { shaderModelVersion = v; }
   void setSourceFileName(llvm::StringRef name) { sourceFileName = name; }
   void setSourceFileContent(llvm::StringRef c) { sourceFileContent = c; }
+  void setBound(uint32_t b) { bound = b; }
 
 private:
-  uint32_t bound; ///< The <result-id> bound: the next unused one
+  uint32_t bound;
   uint32_t shaderModelVersion;
 
   // "Metadata" instructions

+ 51 - 2
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -91,6 +91,9 @@ public:
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Integer; }
 
+  uint32_t getBitwidth() const { return bitwidth; }
+  bool isSignedInt() const { return isSigned; }
+
 private:
   uint32_t bitwidth;
   bool isSigned;
@@ -102,6 +105,8 @@ public:
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Float; }
 
+  uint32_t getBitwidth() const { return bitwidth; }
+
 private:
   uint32_t bitwidth;
 };
@@ -113,6 +118,11 @@ public:
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Vector; }
 
+  const SpirvType *getElementType() const {
+    return llvm::cast<SpirvType>(elementType);
+  }
+  uint32_t getElementCount() const { return elementCount; }
+
 private:
   const ScalarType *elementType;
   uint32_t elementCount;
@@ -126,6 +136,12 @@ public:
 
   bool operator==(const MatrixType &that) const;
 
+  const SpirvType *getVecType() const {
+    return llvm::cast<SpirvType>(vectorType);
+  }
+  uint32_t getVecCount() const { return vectorCount; }
+  bool isRowMajorMat() const { return isRowMajor; }
+
 private:
   const VectorType *vectorType;
   uint32_t vectorCount;
@@ -144,17 +160,34 @@ public:
     Yes = 1,
     No = 2,
   };
+  enum class WithDepth : uint32_t {
+    No = 0,
+    Yes = 1,
+    Unknown = 2,
+  };
 
-  ImageType(const NumericalType *sampledType, spv::Dim, bool isArrayed,
-            bool isMultiSampled, WithSampler sampled, spv::ImageFormat);
+  ImageType(const NumericalType *sampledType, spv::Dim, WithDepth depth,
+            bool isArrayed, bool isMultiSampled, WithSampler sampled,
+            spv::ImageFormat);
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Image; }
 
   bool operator==(const ImageType &that) const;
 
+  const SpirvType *getSampledType() const {
+    return llvm::cast<SpirvType>(sampledType);
+  }
+  spv::Dim getDimension() const { return dimension; }
+  WithDepth getDepth() const { return imageDepth; }
+  bool isArrayedImage() const { return isArrayed; }
+  bool isMSImage() const { return isMultiSampled; }
+  WithSampler withSampler() const { return isSampled; }
+  spv::ImageFormat getImageFormat() const { return imageFormat; }
+
 private:
   const NumericalType *sampledType;
   spv::Dim dimension;
+  WithDepth imageDepth;
   bool isArrayed;
   bool isMultiSampled;
   WithSampler isSampled;
@@ -177,6 +210,8 @@ public:
     return t->getKind() == TK_SampledImage;
   }
 
+  const ImageType *getImageType() const { return imageType; }
+
 private:
   const ImageType *imageType;
 };
@@ -186,6 +221,9 @@ public:
   ArrayType(const SpirvType *elemType, uint32_t elemCount)
       : SpirvType(TK_Array), elementType(elemType), elementCount(elemCount) {}
 
+  const SpirvType *getElementType() const { return elementType; }
+  uint32_t getElementCount() const { return elementCount; }
+
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Array; }
 
 private:
@@ -202,6 +240,8 @@ public:
     return t->getKind() == TK_RuntimeArray;
   }
 
+  const SpirvType *getElementType() const { return elementType; }
+
 private:
   const SpirvType *elementType;
 };
@@ -215,6 +255,9 @@ public:
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Struct; }
 
   bool isReadOnly() const { return readOnly; }
+  std::string getStructName() const { return structName; }
+  llvm::ArrayRef<const SpirvType *> getFieldTypes() const { return fieldTypes; }
+  llvm::ArrayRef<std::string> getFieldNames() const { return fieldNames; }
 
   bool operator==(const StructType &that) const;
 
@@ -236,6 +279,9 @@ public:
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Pointer; }
 
+  const SpirvType *getPointeeType() const { return pointeeType; }
+  spv::StorageClass getStorageClass() const { return storageClass; }
+
 private:
   const SpirvType *pointeeType;
   spv::StorageClass storageClass;
@@ -255,6 +301,9 @@ public:
     return returnType == that.returnType && paramTypes == that.paramTypes;
   }
 
+  const SpirvType *getReturnType() const { return returnType; }
+  llvm::ArrayRef<const SpirvType *> getParamTypes() const { return paramTypes; }
+
 private:
   const SpirvType *returnType;
   llvm::SmallVector<const SpirvType *, 8> paramTypes;

+ 10 - 7
tools/clang/include/clang/SPIRV/SpirvVisitor.h

@@ -14,6 +14,7 @@
 namespace clang {
 namespace spirv {
 
+class SpirvContext;
 class SpirvModule;
 class SpirvFunction;
 class SpirvBasicBlock;
@@ -40,19 +41,19 @@ public:
   Visitor &operator=(Visitor &&) = delete;
 
   // Visiting different SPIR-V constructs.
-  virtual bool visit(SpirvModule *, Phase) {}
-  virtual bool visit(SpirvFunction *, Phase) {}
-  virtual bool visit(SpirvBasicBlock *, Phase) {}
+  virtual bool visit(SpirvModule *, Phase) { return true; }
+  virtual bool visit(SpirvFunction *, Phase) { return true; }
+  virtual bool visit(SpirvBasicBlock *, Phase) { return true; }
 
   /// The "sink" visit function for all instructions.
   ///
   /// By default, all other visit instructions redirect to this visit function.
   /// So that you want override this visit function to handle all instructions,
   /// regardless of their polymorphism.
-  virtual bool visitInstruction(SpirvInstruction *) {}
+  virtual bool visitInstruction(SpirvInstruction *) { return true; }
 
 #define DEFINE_VISIT_METHOD(cls)                                               \
-  virtual bool visit(cls *i) { visitInstruction(i); }
+  virtual bool visit(cls *i) { return visitInstruction(i); }
 
   DEFINE_VISIT_METHOD(SpirvCapability)
   DEFINE_VISIT_METHOD(SpirvExtension)
@@ -114,12 +115,14 @@ public:
 #undef DEFINE_VISIT_METHOD
 
 protected:
-  explicit Visitor(const SpirvCodeGenOptions &opts) : spvOptions(opts) {}
+  explicit Visitor(const SpirvCodeGenOptions &opts, SpirvContext &ctx)
+      : spvOptions(opts), context(ctx) {}
 
   const SpirvCodeGenOptions &getCodeGenOptions() const { return spvOptions; }
 
-private:
+protected:
   const SpirvCodeGenOptions &spvOptions;
+  SpirvContext &context;
 };
 
 } // namespace spirv

+ 275 - 56
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -13,6 +13,7 @@
 #include "clang/SPIRV/SpirvFunction.h"
 #include "clang/SPIRV/SpirvInstruction.h"
 #include "clang/SPIRV/SpirvModule.h"
+#include "clang/SPIRV/SpirvType.h"
 #include "clang/SPIRV/String.h"
 
 namespace {
@@ -63,6 +64,16 @@ void EmitVisitor::emitDebugNameForInstruction(uint32_t resultId,
   debugBinary.insert(debugBinary.end(), curInst.begin(), curInst.end());
 }
 
+void EmitVisitor::initInstruction(SpirvInstruction *inst) {
+  if (inst->hasResultType()) {
+    const uint32_t resultTypeId =
+        typeHandler.emitType(inst->getResultType(), inst->getLayoutRule());
+    inst->setResultTypeId(resultTypeId);
+  }
+  curInst.clear();
+  curInst.push_back(static_cast<uint32_t>(inst->getopcode()));
+}
+
 void EmitVisitor::initInstruction(spv::Op op) {
   curInst.clear();
   curInst.push_back(static_cast<uint32_t>(op));
@@ -113,7 +124,11 @@ void EmitVisitor::encodeString(llvm::StringRef value) {
 }
 
 bool EmitVisitor::visit(SpirvModule *m, Phase phase) {
-  // No pre or post ops for SpirvModule.
+  // No pre-visit operations needed for SpirvModule.
+
+  if (phase == Visitor::Phase::Done)
+    m->setBound(takeNextId());
+
   return true;
 }
 
@@ -122,13 +137,20 @@ bool EmitVisitor::visit(SpirvFunction *fn, Phase phase) {
 
   // Before emitting the function
   if (phase == Visitor::Phase::Init) {
+    const uint32_t returnTypeId =
+        typeHandler.emitType(fn->getReturnType(), SpirvLayoutRule::Void);
+    const uint32_t functionTypeId =
+        typeHandler.emitType(fn->getFunctionType(), SpirvLayoutRule::Void);
+    fn->setReturnTypeId(returnTypeId);
+    fn->setFunctionTypeId(functionTypeId);
+
     // Emit OpFunction
     initInstruction(spv::Op::OpFunction);
-    curInst.push_back(fn->getReturnTypeId());
+    curInst.push_back(returnTypeId);
     curInst.push_back(fn->getResultId());
     curInst.push_back(
         static_cast<uint32_t>(spv::FunctionControlMask::MaskNone));
-    curInst.push_back(fn->getFunctionTypeId());
+    curInst.push_back(functionTypeId);
     finalizeInstruction();
     emitDebugNameForInstruction(fn->getResultId(), fn->getFunctionName());
   }
@@ -161,21 +183,21 @@ bool EmitVisitor::visit(SpirvBasicBlock *bb, Phase phase) {
 }
 
 bool EmitVisitor::visit(SpirvCapability *cap) {
-  initInstruction(cap->getopcode());
+  initInstruction(cap);
   curInst.push_back(static_cast<uint32_t>(cap->getCapability()));
   finalizeInstruction();
   return true;
 }
 
 bool EmitVisitor::visit(SpirvExtension *ext) {
-  initInstruction(ext->getopcode());
+  initInstruction(ext);
   encodeString(ext->getExtensionName());
   finalizeInstruction();
   return true;
 }
 
 bool EmitVisitor::visit(SpirvExtInstImport *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultId());
   encodeString(inst->getExtendedInstSetName());
   finalizeInstruction();
@@ -183,7 +205,7 @@ bool EmitVisitor::visit(SpirvExtInstImport *inst) {
 }
 
 bool EmitVisitor::visit(SpirvMemoryModel *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(static_cast<uint32_t>(inst->getAddressingModel()));
   curInst.push_back(static_cast<uint32_t>(inst->getMemoryModel()));
   finalizeInstruction();
@@ -191,7 +213,7 @@ bool EmitVisitor::visit(SpirvMemoryModel *inst) {
 }
 
 bool EmitVisitor::visit(SpirvEntryPoint *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(static_cast<uint32_t>(inst->getExecModel()));
   curInst.push_back(inst->getEntryPoint()->getResultId());
   encodeString(inst->getEntryPointName());
@@ -203,7 +225,7 @@ bool EmitVisitor::visit(SpirvEntryPoint *inst) {
 }
 
 bool EmitVisitor::visit(SpirvExecutionMode *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getEntryPoint()->getResultId());
   curInst.push_back(static_cast<uint32_t>(inst->getExecutionMode()));
   curInst.insert(curInst.end(), inst->getParams().begin(),
@@ -213,7 +235,7 @@ bool EmitVisitor::visit(SpirvExecutionMode *inst) {
 }
 
 bool EmitVisitor::visit(SpirvString *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultId());
   encodeString(inst->getString());
   finalizeInstruction();
@@ -221,7 +243,7 @@ bool EmitVisitor::visit(SpirvString *inst) {
 }
 
 bool EmitVisitor::visit(SpirvSource *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(static_cast<uint32_t>(inst->getSourceLanguage()));
   curInst.push_back(static_cast<uint32_t>(inst->getVersion()));
   if (inst->hasFile())
@@ -239,14 +261,14 @@ bool EmitVisitor::visit(SpirvSource *inst) {
 }
 
 bool EmitVisitor::visit(SpirvModuleProcessed *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   encodeString(inst->getProcess());
   finalizeInstruction();
   return true;
 }
 
 bool EmitVisitor::visit(SpirvDecoration *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getTarget()->getResultId());
   if (inst->isMemberDecoration())
     curInst.push_back(inst->getMemberIndex());
@@ -258,7 +280,7 @@ bool EmitVisitor::visit(SpirvDecoration *inst) {
 }
 
 bool EmitVisitor::visit(SpirvVariable *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(static_cast<uint32_t>(inst->getStorageClass()));
@@ -270,7 +292,7 @@ bool EmitVisitor::visit(SpirvVariable *inst) {
 }
 
 bool EmitVisitor::visit(SpirvFunctionParameter *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   finalizeInstruction();
@@ -279,7 +301,7 @@ bool EmitVisitor::visit(SpirvFunctionParameter *inst) {
 }
 
 bool EmitVisitor::visit(SpirvLoopMerge *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getMergeBlock()->getLabelId());
   curInst.push_back(inst->getContinueTarget()->getLabelId());
   curInst.push_back(static_cast<uint32_t>(inst->getLoopControlMask()));
@@ -289,7 +311,7 @@ bool EmitVisitor::visit(SpirvLoopMerge *inst) {
 }
 
 bool EmitVisitor::visit(SpirvSelectionMerge *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getMergeBlock()->getLabelId());
   curInst.push_back(static_cast<uint32_t>(inst->getSelectionControlMask()));
   finalizeInstruction();
@@ -298,7 +320,7 @@ bool EmitVisitor::visit(SpirvSelectionMerge *inst) {
 }
 
 bool EmitVisitor::visit(SpirvBranch *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getTargetLabel()->getLabelId());
   finalizeInstruction();
   emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
@@ -306,7 +328,7 @@ bool EmitVisitor::visit(SpirvBranch *inst) {
 }
 
 bool EmitVisitor::visit(SpirvBranchConditional *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getCondition()->getResultId());
   curInst.push_back(inst->getTrueLabel()->getLabelId());
   curInst.push_back(inst->getFalseLabel()->getLabelId());
@@ -316,21 +338,21 @@ bool EmitVisitor::visit(SpirvBranchConditional *inst) {
 }
 
 bool EmitVisitor::visit(SpirvKill *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   finalizeInstruction();
   emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
   return true;
 }
 
 bool EmitVisitor::visit(SpirvReturn *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   finalizeInstruction();
   emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
   return true;
 }
 
 bool EmitVisitor::visit(SpirvSwitch *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getSelector()->getResultId());
   curInst.push_back(inst->getDefaultLabel()->getLabelId());
   for (const auto &target : inst->getTargets()) {
@@ -343,14 +365,14 @@ bool EmitVisitor::visit(SpirvSwitch *inst) {
 }
 
 bool EmitVisitor::visit(SpirvUnreachable *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   finalizeInstruction();
   emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
   return true;
 }
 
 bool EmitVisitor::visit(SpirvAccessChain *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getBase()->getResultId());
@@ -363,7 +385,7 @@ bool EmitVisitor::visit(SpirvAccessChain *inst) {
 
 bool EmitVisitor::visit(SpirvAtomic *inst) {
   const auto op = inst->getopcode();
-  initInstruction(op);
+  initInstruction(inst);
   if (op != spv::Op::OpAtomicStore && op != spv::Op::OpAtomicFlagClear) {
     curInst.push_back(inst->getResultTypeId());
     curInst.push_back(inst->getResultId());
@@ -383,7 +405,7 @@ bool EmitVisitor::visit(SpirvAtomic *inst) {
 }
 
 bool EmitVisitor::visit(SpirvBarrier *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   if (inst->isControlBarrier())
     curInst.push_back(static_cast<uint32_t>(inst->getExecutionScope()));
   curInst.push_back(static_cast<uint32_t>(inst->getMemoryScope()));
@@ -394,7 +416,7 @@ bool EmitVisitor::visit(SpirvBarrier *inst) {
 }
 
 bool EmitVisitor::visit(SpirvBinaryOp *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getOperand1()->getResultId());
@@ -405,7 +427,7 @@ bool EmitVisitor::visit(SpirvBinaryOp *inst) {
 }
 
 bool EmitVisitor::visit(SpirvBitFieldExtract *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getBase()->getResultId());
@@ -417,7 +439,7 @@ bool EmitVisitor::visit(SpirvBitFieldExtract *inst) {
 }
 
 bool EmitVisitor::visit(SpirvBitFieldInsert *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getBase()->getResultId());
@@ -430,7 +452,7 @@ bool EmitVisitor::visit(SpirvBitFieldInsert *inst) {
 }
 
 bool EmitVisitor::visit(SpirvConstantBoolean *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   finalizeInstruction();
@@ -439,7 +461,7 @@ bool EmitVisitor::visit(SpirvConstantBoolean *inst) {
 }
 
 bool EmitVisitor::visit(SpirvConstantInteger *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   // 16-bit cases
@@ -481,7 +503,7 @@ bool EmitVisitor::visit(SpirvConstantInteger *inst) {
 }
 
 bool EmitVisitor::visit(SpirvConstantFloat *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   if (inst->getBitwidth() == 16) {
@@ -509,7 +531,7 @@ bool EmitVisitor::visit(SpirvConstantFloat *inst) {
 }
 
 bool EmitVisitor::visit(SpirvConstantComposite *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   for (const auto constituent : inst->getConstituents())
@@ -520,7 +542,7 @@ bool EmitVisitor::visit(SpirvConstantComposite *inst) {
 }
 
 bool EmitVisitor::visit(SpirvConstantNull *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   finalizeInstruction();
@@ -529,7 +551,7 @@ bool EmitVisitor::visit(SpirvConstantNull *inst) {
 }
 
 bool EmitVisitor::visit(SpirvComposite *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   for (const auto constituent : inst->getConstituents())
@@ -540,7 +562,7 @@ bool EmitVisitor::visit(SpirvComposite *inst) {
 }
 
 bool EmitVisitor::visit(SpirvCompositeExtract *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getComposite()->getResultId());
@@ -552,7 +574,7 @@ bool EmitVisitor::visit(SpirvCompositeExtract *inst) {
 }
 
 bool EmitVisitor::visit(SpirvCompositeInsert *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getObject()->getResultId());
@@ -565,19 +587,19 @@ bool EmitVisitor::visit(SpirvCompositeInsert *inst) {
 }
 
 bool EmitVisitor::visit(SpirvEmitVertex *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   finalizeInstruction();
   return true;
 }
 
 bool EmitVisitor::visit(SpirvEndPrimitive *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   finalizeInstruction();
   return true;
 }
 
 bool EmitVisitor::visit(SpirvExtInst *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getInstructionSet()->getResultId());
@@ -590,7 +612,7 @@ bool EmitVisitor::visit(SpirvExtInst *inst) {
 }
 
 bool EmitVisitor::visit(SpirvFunctionCall *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getFunction()->getResultId());
@@ -602,7 +624,7 @@ bool EmitVisitor::visit(SpirvFunctionCall *inst) {
 }
 
 bool EmitVisitor::visit(SpirvNonUniformBinaryOp *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(static_cast<uint32_t>(inst->getExecutionScope()));
@@ -614,7 +636,7 @@ bool EmitVisitor::visit(SpirvNonUniformBinaryOp *inst) {
 }
 
 bool EmitVisitor::visit(SpirvNonUniformElect *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(static_cast<uint32_t>(inst->getExecutionScope()));
@@ -624,7 +646,7 @@ bool EmitVisitor::visit(SpirvNonUniformElect *inst) {
 }
 
 bool EmitVisitor::visit(SpirvNonUniformUnaryOp *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(static_cast<uint32_t>(inst->getExecutionScope()));
@@ -637,7 +659,7 @@ bool EmitVisitor::visit(SpirvNonUniformUnaryOp *inst) {
 }
 
 bool EmitVisitor::visit(SpirvImageOp *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
 
   if (!inst->isImageWrite()) {
     curInst.push_back(inst->getResultTypeId());
@@ -681,7 +703,7 @@ bool EmitVisitor::visit(SpirvImageOp *inst) {
 }
 
 bool EmitVisitor::visit(SpirvImageQuery *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getImage()->getResultId());
@@ -695,7 +717,7 @@ bool EmitVisitor::visit(SpirvImageQuery *inst) {
 }
 
 bool EmitVisitor::visit(SpirvImageSparseTexelsResident *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getResidentCode()->getResultId());
@@ -705,7 +727,7 @@ bool EmitVisitor::visit(SpirvImageSparseTexelsResident *inst) {
 }
 
 bool EmitVisitor::visit(SpirvImageTexelPointer *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getImage()->getResultId());
@@ -717,7 +739,7 @@ bool EmitVisitor::visit(SpirvImageTexelPointer *inst) {
 }
 
 bool EmitVisitor::visit(SpirvLoad *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getPointer()->getResultId());
@@ -729,7 +751,7 @@ bool EmitVisitor::visit(SpirvLoad *inst) {
 }
 
 bool EmitVisitor::visit(SpirvSampledImage *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getImage()->getResultId());
@@ -740,7 +762,7 @@ bool EmitVisitor::visit(SpirvSampledImage *inst) {
 }
 
 bool EmitVisitor::visit(SpirvSelect *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getCondition()->getResultId());
@@ -752,7 +774,7 @@ bool EmitVisitor::visit(SpirvSelect *inst) {
 }
 
 bool EmitVisitor::visit(SpirvSpecConstantBinaryOp *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(static_cast<uint32_t>(inst->getSpecConstantopcode()));
@@ -764,7 +786,7 @@ bool EmitVisitor::visit(SpirvSpecConstantBinaryOp *inst) {
 }
 
 bool EmitVisitor::visit(SpirvSpecConstantUnaryOp *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(static_cast<uint32_t>(inst->getSpecConstantopcode()));
@@ -775,7 +797,7 @@ bool EmitVisitor::visit(SpirvSpecConstantUnaryOp *inst) {
 }
 
 bool EmitVisitor::visit(SpirvStore *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getPointer()->getResultId());
   curInst.push_back(inst->getObject()->getResultId());
   if (inst->hasMemoryAccessSemantics())
@@ -786,7 +808,7 @@ bool EmitVisitor::visit(SpirvStore *inst) {
 }
 
 bool EmitVisitor::visit(SpirvUnaryOp *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getOperand()->getResultId());
@@ -796,7 +818,7 @@ bool EmitVisitor::visit(SpirvUnaryOp *inst) {
 }
 
 bool EmitVisitor::visit(SpirvVectorShuffle *inst) {
-  initInstruction(inst->getopcode());
+  initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(inst->getResultId());
   curInst.push_back(inst->getVec1()->getResultId());
@@ -808,5 +830,202 @@ bool EmitVisitor::visit(SpirvVectorShuffle *inst) {
   return true;
 }
 
+// EmitTypeHandler ------
+
+void EmitTypeHandler::initTypeInstruction(spv::Op op) {
+  curTypeInst.clear();
+  curTypeInst.push_back(static_cast<uint32_t>(op));
+}
+
+void EmitTypeHandler::finalizeTypeInstruction() {
+  curTypeInst[0] |= static_cast<uint32_t>(curTypeInst.size()) << 16;
+  typeConstantBinary->insert(typeConstantBinary->end(), curTypeInst.begin(),
+                             curTypeInst.end());
+}
+
+uint32_t EmitTypeHandler::getResultIdForType(const SpirvType *type,
+                                             SpirvLayoutRule rule,
+                                             bool *alreadyExists) {
+  assert(alreadyExists);
+  // Check if the type has already been emitted.
+  auto foundType = emittedTypes.find(type);
+  if (foundType != emittedTypes.end()) {
+    auto foundLayoutRule = foundType->second.find(rule);
+    if (foundLayoutRule != foundType->second.end()) {
+      *alreadyExists = true;
+      return foundLayoutRule->second;
+    }
+  }
+
+  *alreadyExists = false;
+  const uint32_t id = takeNextIdFunction();
+  emittedTypes[type][rule] = id;
+  return id;
+}
+
+uint32_t EmitTypeHandler::emitType(const SpirvType *type,
+                                   SpirvLayoutRule rule) {
+  //
+  // TODO: This method is currently missing decorations for types completely.
+  //
+
+  bool alreadyExists = false;
+  const uint32_t id = getResultIdForType(type, rule, &alreadyExists);
+
+  // If the type has already been emitted, we just need to return its
+  // <result-id>.
+  if (alreadyExists)
+    return id;
+
+  if (isa<VoidType>(type)) {
+    initTypeInstruction(spv::Op::OpTypeVoid);
+    curTypeInst.push_back(id);
+    finalizeTypeInstruction();
+  }
+  // Boolean types
+  else if (isa<BoolType>(type)) {
+    initTypeInstruction(spv::Op::OpTypeBool);
+    curTypeInst.push_back(id);
+    finalizeTypeInstruction();
+  }
+  // Integer types
+  else if (const auto *intType = dyn_cast<IntegerType>(type)) {
+    initTypeInstruction(spv::Op::OpTypeInt);
+    curTypeInst.push_back(id);
+    curTypeInst.push_back(intType->getBitwidth());
+    curTypeInst.push_back(intType->isSignedInt() ? 1 : 0);
+    finalizeTypeInstruction();
+  }
+  // Float types
+  else if (const auto *floatType = dyn_cast<FloatType>(type)) {
+    initTypeInstruction(spv::Op::OpTypeFloat);
+    curTypeInst.push_back(id);
+    curTypeInst.push_back(floatType->getBitwidth());
+    finalizeTypeInstruction();
+  }
+  // Vector types
+  else if (const auto *vecType = dyn_cast<VectorType>(type)) {
+    const uint32_t elementTypeId = emitType(vecType->getElementType(), rule);
+    initTypeInstruction(spv::Op::OpTypeVector);
+    curTypeInst.push_back(id);
+    curTypeInst.push_back(elementTypeId);
+    curTypeInst.push_back(vecType->getElementCount());
+    finalizeTypeInstruction();
+  }
+  // Matrix types
+  else if (const auto *matType = dyn_cast<MatrixType>(type)) {
+    const uint32_t vecTypeId = emitType(matType->getVecType(), rule);
+    initTypeInstruction(spv::Op::OpTypeMatrix);
+    curTypeInst.push_back(id);
+    curTypeInst.push_back(vecTypeId);
+    curTypeInst.push_back(matType->getVecCount());
+    finalizeTypeInstruction();
+    // Note that RowMajor and ColMajor decorations only apply to structure
+    // members, and should not be handled here.
+  }
+  // Image types
+  else if (const auto *imageType = dyn_cast<ImageType>(type)) {
+    const uint32_t sampledTypeId = emitType(imageType->getSampledType(), rule);
+    initTypeInstruction(spv::Op::OpTypeImage);
+    curTypeInst.push_back(id);
+    curTypeInst.push_back(sampledTypeId);
+    curTypeInst.push_back(static_cast<uint32_t>(imageType->getDimension()));
+    curTypeInst.push_back(static_cast<uint32_t>(imageType->getDepth()));
+    curTypeInst.push_back(imageType->isArrayedImage() ? 1 : 0);
+    curTypeInst.push_back(imageType->isMSImage() ? 1 : 0);
+    curTypeInst.push_back(static_cast<uint32_t>(imageType->withSampler()));
+    curTypeInst.push_back(static_cast<uint32_t>(imageType->getImageFormat()));
+    finalizeTypeInstruction();
+  }
+  // Sampler types
+  else if (const auto *samplerType = dyn_cast<SamplerType>(type)) {
+    initTypeInstruction(spv::Op::OpTypeSampler);
+    curTypeInst.push_back(id);
+    finalizeTypeInstruction();
+  }
+  // SampledImage types
+  else if (const auto *sampledImageType = dyn_cast<SampledImageType>(type)) {
+    const uint32_t imageTypeId =
+        emitType(sampledImageType->getImageType(), rule);
+    initTypeInstruction(spv::Op::OpTypeSampledImage);
+    curTypeInst.push_back(id);
+    curTypeInst.push_back(imageTypeId);
+    finalizeTypeInstruction();
+  }
+  // Array types
+  else if (const auto *arrayType = dyn_cast<ArrayType>(type)) {
+    // Emit the OpConstant instruction that is needed to get the result-id for
+    // the array length.
+    SpirvConstant *constant =
+        context.getConstantUint32(arrayType->getElementCount());
+    if (constant->getResultId() == 0) {
+      constant->setResultId(takeNextIdFunction());
+    }
+    IntegerType constantIntType(32, 0);
+    const uint32_t uint32TypeId = emitType(&constantIntType, rule);
+    initTypeInstruction(spv::Op::OpConstant);
+    curTypeInst.push_back(uint32TypeId);
+    curTypeInst.push_back(constant->getResultId());
+    curTypeInst.push_back(arrayType->getElementCount());
+    finalizeTypeInstruction();
+
+    // Emit the OpTypeArray instruction
+    const uint32_t elemTypeId = emitType(arrayType->getElementType(), rule);
+    initTypeInstruction(spv::Op::OpTypeArray);
+    curTypeInst.push_back(id);
+    curTypeInst.push_back(elemTypeId);
+    curTypeInst.push_back(constant->getResultId());
+    finalizeTypeInstruction();
+  }
+  // RuntimeArray types
+  else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
+    const uint32_t elemTypeId = emitType(raType->getElementType(), rule);
+    initTypeInstruction(spv::Op::OpTypeRuntimeArray);
+    curTypeInst.push_back(id);
+    curTypeInst.push_back(elemTypeId);
+    finalizeTypeInstruction();
+  }
+  // Structure types
+  else if (const auto *structType = dyn_cast<StructType>(type)) {
+    llvm::SmallVector<uint32_t, 4> fieldTypeIds;
+    for (auto *fieldType : structType->getFieldTypes())
+      fieldTypeIds.push_back(emitType(fieldType, rule));
+    initTypeInstruction(spv::Op::OpTypeStruct);
+    curTypeInst.push_back(id);
+    for (auto fieldTypeId : fieldTypeIds)
+      curTypeInst.push_back(fieldTypeId);
+    finalizeTypeInstruction();
+  }
+  // Pointer types
+  else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
+    const uint32_t pointeeType = emitType(ptrType->getPointeeType(), rule);
+    initTypeInstruction(spv::Op::OpTypePointer);
+    curTypeInst.push_back(id);
+    curTypeInst.push_back(static_cast<uint32_t>(ptrType->getStorageClass()));
+    curTypeInst.push_back(pointeeType);
+    finalizeTypeInstruction();
+  }
+  // Function types
+  else if (const auto *fnType = dyn_cast<FunctionType>(type)) {
+    const uint32_t retTypeId = emitType(fnType->getReturnType(), rule);
+    llvm::SmallVector<uint32_t, 4> paramTypeIds;
+    for (auto *paramType : fnType->getParamTypes())
+      paramTypeIds.push_back(emitType(paramType, rule));
+
+    initTypeInstruction(spv::Op::OpTypeFunction);
+    curTypeInst.push_back(id);
+    curTypeInst.push_back(retTypeId);
+    for (auto paramTypeId : paramTypeIds)
+      curTypeInst.push_back(paramTypeId);
+    finalizeTypeInstruction();
+  }
+  // Unhandled types
+  else {
+    llvm_unreachable("unhandled type in emitType");
+  }
+
+  return id;
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 6 - 5
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -235,8 +235,8 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
       const auto sampledType = hlsl::GetHLSLResourceResultType(type);
       return spvContext.getImageType(
           lowerType(getElementType(sampledType, srcLoc), rule, srcLoc), dim,
-          isArray, isMS, ImageType::WithSampler::Yes,
-          spv::ImageFormat::Unknown);
+          ImageType::WithDepth::Unknown, isArray, isMS,
+          ImageType::WithSampler::Yes, spv::ImageFormat::Unknown);
     }
 
     // There is no RWTexture3DArray
@@ -250,7 +250,7 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
           translateSampledTypeToImageFormat(sampledType, srcLoc);
       return spvContext.getImageType(
           lowerType(getElementType(sampledType, srcLoc), rule, srcLoc), dim,
-          isArray,
+          ImageType::WithDepth::Unknown, isArray,
           /*isMultiSampled=*/false, /*sampled=*/ImageType::WithSampler::No,
           format);
     }
@@ -334,7 +334,7 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
     const auto format = translateSampledTypeToImageFormat(sampledType, srcLoc);
     return spvContext.getImageType(
         lowerType(getElementType(sampledType, srcLoc), rule, srcLoc),
-        spv::Dim::Buffer,
+        spv::Dim::Buffer, ImageType::WithDepth::Unknown,
         /*isArrayed=*/false, /*isMultiSampled=*/false,
         /*sampled*/ name == "Buffer" ? ImageType::WithSampler::Yes
                                      : ImageType::WithSampler::No,
@@ -365,7 +365,8 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
     const auto sampledType = hlsl::GetHLSLResourceResultType(type);
     return spvContext.getImageType(
         lowerType(getElementType(sampledType, srcLoc), rule, srcLoc),
-        spv::Dim::SubpassData, /*isArrayed=*/false,
+        spv::Dim::SubpassData, ImageType::WithDepth::Unknown,
+        /*isArrayed=*/false,
         /*isMultipleSampled=*/name == "SubpassInputMS",
         ImageType::WithSampler::No, spv::ImageFormat::Unknown);
   }

+ 1 - 1
tools/clang/lib/SPIRV/LowerTypeVisitor.h

@@ -23,7 +23,7 @@ class LowerTypeVisitor : public Visitor {
 public:
   LowerTypeVisitor(ASTContext &astCtx, SpirvContext &spvCtx,
                    const SpirvCodeGenOptions &opts)
-      : Visitor(opts), astContext(astCtx), spvContext(spvCtx) {}
+      : Visitor(opts, spvCtx), astContext(astCtx), spvContext(spvCtx) {}
 
 private:
   /// Emits error to the diagnostic engine associated with this visitor.

+ 23 - 6
tools/clang/lib/SPIRV/SPIRVContext.cpp

@@ -74,9 +74,9 @@ const Decoration *SPIRVContext::registerDecoration(const Decoration &d) {
   return &*it;
 }
 
-SpirvContext::SpirvContext()
-    : allocator(), voidType(nullptr), boolType(nullptr), sintTypes({}),
-      uintTypes({}), floatTypes({}), samplerType(nullptr) {
+SpirvContext::SpirvContext(const ASTContext &ctx)
+    : astContext(ctx), allocator(), voidType(nullptr), boolType(nullptr),
+      sintTypes({}), uintTypes({}), floatTypes({}), samplerType(nullptr) {
   voidType = new (this) VoidType;
   boolType = new (this) BoolType;
   samplerType = new (this) SamplerType;
@@ -164,7 +164,9 @@ const MatrixType *SpirvContext::getMatrixType(const SpirvType *elemType,
 }
 
 const ImageType *SpirvContext::getImageType(const SpirvType *sampledType,
-                                            spv::Dim dim, bool arrayed, bool ms,
+                                            spv::Dim dim,
+                                            ImageType::WithDepth depth,
+                                            bool arrayed, bool ms,
                                             ImageType::WithSampler sampled,
                                             spv::ImageFormat format) {
   // We are certain this should be a numerical type. Otherwise, cast causes an
@@ -172,7 +174,7 @@ const ImageType *SpirvContext::getImageType(const SpirvType *sampledType,
   const NumericalType *elemType = cast<NumericalType>(sampledType);
 
   // Create a temporary object for finding in the vector.
-  ImageType type(elemType, dim, arrayed, ms, sampled, format);
+  ImageType type(elemType, dim, depth, arrayed, ms, sampled, format);
 
   auto found = std::find_if(
       imageTypes.begin(), imageTypes.end(),
@@ -182,7 +184,7 @@ const ImageType *SpirvContext::getImageType(const SpirvType *sampledType,
     return *found;
 
   imageTypes.push_back(
-      new (this) ImageType(elemType, dim, arrayed, ms, sampled, format));
+      new (this) ImageType(elemType, dim, depth, arrayed, ms, sampled, format));
 
   return imageTypes.back();
 }
@@ -290,5 +292,20 @@ const StructType *SpirvContext::getByteAddressBufferType(bool isWritable) {
                        {}, !isWritable);
 }
 
+SpirvConstant *SpirvContext::getConstantUint32(uint32_t value,
+                                               SourceLocation loc) {
+  for (auto *constant : constants)
+    if (auto *intConst = dyn_cast<SpirvConstantInteger>(constant))
+      if (!intConst->isSigned() && intConst->getBitwidth() == 32 &&
+          intConst->getUnsignedInt32Value() == value)
+        return constant;
+
+  // Couldn't find the constant. Create one.
+  SpirvConstant *intConst =
+      new (this) SpirvConstantInteger(value, astContext.UnsignedIntTy, 0, loc);
+  constants.push_back(intConst);
+  return intConst;
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 3 - 2
tools/clang/lib/SPIRV/SpirvFunction.cpp

@@ -16,8 +16,9 @@ namespace spirv {
 SpirvFunction::SpirvFunction(QualType type, uint32_t id,
                              spv::FunctionControlMask control,
                              SourceLocation loc, llvm::StringRef name)
-    : functionType(type), functionId(id), functionControl(control),
-      functionLoc(loc), functionName(name) {}
+    : functionId(id), astReturnType(type), returnType(nullptr), returnTypeId(0),
+      fnType(nullptr), fnTypeId(0), functionControl(control), functionLoc(loc),
+      functionName(name) {}
 
 bool SpirvFunction::invokeVisitor(Visitor *visitor) {
   if (!visitor->visit(this, Visitor::Phase::Init))

+ 11 - 9
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -10,10 +10,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "clang/SPIRV/SpirvInstruction.h"
 #include "clang/SPIRV/BitwiseCast.h"
 #include "clang/SPIRV/SpirvBasicBlock.h"
 #include "clang/SPIRV/SpirvFunction.h"
-#include "clang/SPIRV/SpirvInstruction.h"
 #include "clang/SPIRV/SpirvVisitor.h"
 #include "clang/SPIRV/String.h"
 
@@ -83,7 +83,9 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvVectorShuffle)
 
 SpirvInstruction::SpirvInstruction(Kind k, spv::Op op, QualType type,
                                    uint32_t id, SourceLocation loc)
-    : kind(k), opcode(op), resultType(type), resultId(id), srcLoc(loc) {}
+    : kind(k), opcode(op), astResultType(type), resultId(id), srcLoc(loc),
+      debugName(), resultType(nullptr), resultTypeId(0),
+      layoutRule(SpirvLayoutRule::Void) {}
 
 SpirvCapability::SpirvCapability(SourceLocation loc, spv::Capability cap)
     : SpirvInstruction(IK_Capability, spv::Op::OpCapability, QualType(),
@@ -156,11 +158,11 @@ SpirvDecoration::SpirvDecoration(SourceLocation loc,
                                  llvm::ArrayRef<uint32_t> p,
                                  llvm::Optional<uint32_t> idx)
     : SpirvInstruction(IK_Decoration,
-                       index.hasValue() ? spv::Op::OpMemberDecorate
-                                        : spv::Op::OpDecorate,
+                       idx.hasValue() ? spv::Op::OpMemberDecorate
+                                      : spv::Op::OpDecorate,
                        /*type*/ {}, /*id*/ 0, loc),
-      target(targetInst), decoration(decor), params(p.begin(), p.end()),
-      index(idx) {}
+      target(targetInst), decoration(decor), index(idx),
+      params(p.begin(), p.end()) {}
 
 SpirvDecoration::SpirvDecoration(SourceLocation loc,
                                  SpirvInstruction *targetInst,
@@ -168,10 +170,10 @@ SpirvDecoration::SpirvDecoration(SourceLocation loc,
                                  llvm::StringRef strParam,
                                  llvm::Optional<uint32_t> idx)
     : SpirvInstruction(IK_Decoration,
-                       index.hasValue() ? spv::Op::OpMemberDecorate
-                                        : spv::Op::OpDecorate,
+                       idx.hasValue() ? spv::Op::OpMemberDecorate
+                                      : spv::Op::OpDecorate,
                        /*type*/ {}, /*id*/ 0, loc),
-      target(targetInst), decoration(decor), params(), index(idx) {
+      target(targetInst), decoration(decor), index(idx), params() {
   const auto &stringWords = string::encodeSPIRVString(strParam);
   params.insert(params.end(), stringWords.begin(), stringWords.end());
 }

+ 4 - 3
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -37,9 +37,10 @@ bool MatrixType::operator==(const MatrixType &that) const {
          isRowMajor == that.isRowMajor;
 }
 
-ImageType::ImageType(const NumericalType *type, spv::Dim dim, bool arrayed,
-                     bool ms, WithSampler sampled, spv::ImageFormat format)
-    : SpirvType(TK_Image), sampledType(type), dimension(dim),
+ImageType::ImageType(const NumericalType *type, spv::Dim dim, WithDepth depth,
+                     bool arrayed, bool ms, WithSampler sampled,
+                     spv::ImageFormat format)
+    : SpirvType(TK_Image), sampledType(type), dimension(dim), imageDepth(depth),
       isArrayed(arrayed), isMultiSampled(ms), isSampled(sampled),
       imageFormat(format) {}