浏览代码

[spirv] Refine ModuleBuilder and translate float4() function structure (#408)

* [spirv] Refine SPIR-V module builder

* Use Type and Decoration for storing types and decorations
* Use vector of unique pointers to store basic blocks and functions
* Remove unnecessary beginModule() and endModule() methods
* Remove unnecessary error status in module builder
* Allow constructing multiple basic blocks at the same time
* More structured representations inside spirv::Module
* Enable spirv::Type and spirv::Decoration to convert to SPIR-V words

* [spirv] Translate float4() function structure

* float and float{2|3|4} types are handled.
* float4() function type are handled.
* Basic function structure are handled.
Lei Zhang 8 年之前
父节点
当前提交
87a448b43e

+ 4 - 0
tools/clang/include/clang/SPIRV/Decoration.h

@@ -137,6 +137,10 @@ public:
             memberIndex.getValue() == other.memberIndex.getValue());
   }
 
+  // \brief Construct the SPIR-V words for this decoration with the given
+  // target <result-id>.
+  std::vector<uint32_t> withTargetId(uint32_t targetId) const;
+
 private:
   /// \brief prevent public APIs from creating Decoration objects.
   Decoration(spv::Decoration dec_id, llvm::SmallVector<uint32_t, 2> arg = {},

+ 63 - 39
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -9,10 +9,12 @@
 #ifndef LLVM_CLANG_SPIRV_MODULEBUILDER_H
 #define LLVM_CLANG_SPIRV_MODULEBUILDER_H
 
+#include <memory>
+
 #include "clang/SPIRV/InstBuilder.h"
 #include "clang/SPIRV/SPIRVContext.h"
 #include "clang/SPIRV/Structure.h"
-#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/MapVector.h"
 
 namespace clang {
 namespace spirv {
@@ -20,55 +22,77 @@ namespace spirv {
 /// \brief SPIR-V module builder.
 ///
 /// This class exports API for constructing SPIR-V binary interactively.
-/// Call beginModule() to start building a SPIR-V module and endModule()
-/// to complete building a module. Call takeModule() to get the final
-/// SPIR-V binary.
+/// At any time, there can only exist at most one function under building;
+/// but there can exist multiple basic blocks under construction.
 class ModuleBuilder {
 public:
-  enum class Status {
-    Success,
-    ErrNestedModule,       ///< Tried to create module inside module
-    ErrNestedFunction,     ///< Tried to create function inside function
-    ErrNestedBasicBlock,   ///< Tried to crate basic block inside basic block
-    ErrDetachedBasicBlock, ///< Tried to create basic block out of function
-    ErrNoActiveFunction,   ///< Tried to finish building non existing function
-    ErrActiveBasicBlock,   ///< Tried to finish building function when there are
-                           ///< active basic block
-    ErrNoActiveBasicBlock, ///< Tried to insert instructions without active
-                           ///< basic block
-  };
-
   /// \brief Constructs a ModuleBuilder with the given SPIR-V context.
   explicit ModuleBuilder(SPIRVContext *);
 
-  /// \brief Begins building a SPIR-V module.
-  Status beginModule();
-  /// \brief Ends building the current module.
-  Status endModule();
-  /// \brief Begins building a SPIR-V function.
-  Status beginFunction(uint32_t funcType, uint32_t returnType);
-  /// \brief Ends building the current function.
-  Status endFunction();
-  /// \brief Begins building a SPIR-V basic block.
-  Status beginBasicBlock();
-  /// \brief Ends building the current SPIR-V basic block with OpReturn.
-  Status endBasicBlockWithReturn();
-
-  /// \brief Takes the SPIR-V module under building.
+  /// \brief Begins building a SPIR-V function. At any time, there can only
+  /// exist at most one function under building. Returns the <result-id> for the
+  /// function on success. Returns zero on failure.
+  uint32_t beginFunction(uint32_t funcType, uint32_t returnType);
+  /// \brief Ends building of the current function. Returns true of success,
+  /// false on failure. All basic blocks constructed from the beginning or
+  /// after ending the previous function will be collected into this function.
+  bool endFunction();
+
+  /// \brief Creates a SPIR-V basic block. On success, returns the <label-id>
+  /// for the basic block. On failure, returns zero.
+  uint32_t bbCreate();
+  /// \brief Ends building the SPIR-V basic block having the given <label-id>
+  /// with OpReturn. Returns true on success, false on failure.
+  bool bbReturn(uint32_t labelId);
+
+  /// \brief Sets insertion point to the basic block with the given <label-id>.
+  /// Returns true on success, false on failure.
+  bool setInsertPoint(uint32_t labelId);
+
+  inline void requireCapability(spv::Capability);
+
+  inline void setAddressingModel(spv::AddressingModel);
+  inline void setMemoryModel(spv::MemoryModel);
+
+  uint32_t getVoidType();
+  uint32_t getFloatType();
+  uint32_t getVec2Type(uint32_t elemType);
+  uint32_t getVec3Type(uint32_t elemType);
+  uint32_t getVec4Type(uint32_t elemType);
+  uint32_t getFunctionType(uint32_t returnType,
+                           const std::vector<uint32_t> &paramTypes);
+
+  /// \brief Takes the SPIR-V module under building. This will consume the
+  /// module under construction.
   std::vector<uint32_t> takeModule();
 
 private:
-  /// \brief Ends building the current basic block.
-  Status endBasicBlock();
-
-  SPIRVContext &theContext;                 ///< The SPIR-V context.
-  SPIRVModule theModule;                    ///< The module under building.
-  llvm::Optional<Function> theFunction;     ///< The function under building.
-  llvm::Optional<BasicBlock> theBasicBlock; ///< The basic block under building.
-  std::vector<uint32_t> constructSite;      ///< InstBuilder construction site.
+  /// \brief Map from basic blocks' <label-id> to their structured
+  /// representation.
+  using OrderedBasicBlockMap =
+      llvm::MapVector<uint32_t, std::unique_ptr<BasicBlock>>;
+
+  SPIRVContext &theContext;              ///< The SPIR-V context.
+  SPIRVModule theModule;                 ///< The module under building.
+  std::unique_ptr<Function> theFunction; ///< The function under building.
+  OrderedBasicBlockMap basicBlocks;      ///< The basic blocks under building.
+  BasicBlock *insertPoint;               ///< The current insertion point.
+  std::vector<uint32_t> constructSite;   ///< InstBuilder construction site.
   InstBuilder instBuilder;
 };
 
+void ModuleBuilder::setAddressingModel(spv::AddressingModel am) {
+  theModule.setAddressingModel(am);
+}
+
+void ModuleBuilder::setMemoryModel(spv::MemoryModel mm) {
+  theModule.setMemoryModel(mm);
+}
+
+void ModuleBuilder::requireCapability(spv::Capability cap) {
+  theModule.addCapability(cap);
+}
+
 } // end namespace spirv
 } // end namespace clang
 

+ 93 - 33
tools/clang/include/clang/SPIRV/Structure.h

@@ -18,11 +18,14 @@
 #ifndef LLVM_CLANG_SPIRV_STRUCTURE_H
 #define LLVM_CLANG_SPIRV_STRUCTURE_H
 
+#include <memory>
 #include <string>
 #include <vector>
 
 #include "spirv/1.0/spirv.hpp11"
 #include "clang/SPIRV/InstBuilder.h"
+#include "clang/SPIRV/Type.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/Optional.h"
 
 namespace clang {
@@ -97,7 +100,7 @@ public:
   /// \brief Adds a parameter to this function.
   inline void addParameter(uint32_t paramResultType, uint32_t paramResultId);
   /// \brief Adds a basic block to this function.
-  inline void addBasicBlock(BasicBlock &&block);
+  inline void addBasicBlock(std::unique_ptr<BasicBlock> block);
 
 private:
   uint32_t resultType;
@@ -106,7 +109,7 @@ private:
   uint32_t funcType;
   /// Parameter <result-type> and <result-id> pairs.
   std::vector<std::pair<uint32_t, uint32_t>> parameters;
-  std::vector<BasicBlock> blocks;
+  std::vector<std::unique_ptr<BasicBlock>> blocks;
 };
 
 /// \brief The class representing a SPIR-V module.
@@ -129,14 +132,14 @@ public:
   /// an empty module.
   void clear();
 
-  /// \brief Sets the id bound to the given bound.
-  inline void setBound(uint32_t newBound);
-
   /// \brief Collects all the SPIR-V words in this module and consumes them
   /// using the consumer within the given InstBuilder. This method is
   /// destructive; the module will be consumed and cleared after calling it.
   void take(InstBuilder *builder);
 
+  /// \brief Sets the id bound to the given bound.
+  inline void setBound(uint32_t newBound);
+
   inline void addCapability(spv::Capability);
   inline void addExtension(std::string extension);
   inline void addExtInstSet(uint32_t setId, std::string extInstSet);
@@ -146,12 +149,15 @@ public:
                             std::string targetName,
                             std::initializer_list<uint32_t> intefaces);
   inline void addExecutionMode(Instruction &&);
+  // TODO: source code debug information
   inline void addDebugName(uint32_t targetId,
                            llvm::Optional<uint32_t> memberIndex,
                            std::string name);
-  inline void addDecoration(Instruction &&);
-  inline void addType(Instruction &&);
-  inline void addFunction(Function &&);
+  inline void addDecoration(const Decoration &decoration, uint32_t targetId);
+  inline void addType(const Type *type, uint32_t resultId);
+  inline void addConstant(const Type &type, Instruction &&constant);
+  // TODO: global variables
+  inline void addFunction(std::unique_ptr<Function>);
 
 private:
   /// \brief The struct representing a SPIR-V module header.
@@ -162,11 +168,19 @@ private:
     /// \brief Feeds the consumer with all the SPIR-V words for this header.
     void collect(const WordConsumer &consumer);
 
-    uint32_t magicNumber;
-    uint32_t version;
-    uint32_t generator;
+    const uint32_t magicNumber;
+    const uint32_t version;
+    const uint32_t generator;
     uint32_t bound;
-    uint32_t reserved;
+    const uint32_t reserved;
+  };
+
+  /// \brief The struct representing an extended instruction set.
+  struct ExtInstSet {
+    inline ExtInstSet(uint32_t id, std::string name);
+
+    const uint32_t resultId;
+    const std::string setName;
   };
 
   /// \brief The struct representing an entry point.
@@ -174,10 +188,10 @@ private:
     inline EntryPoint(spv::ExecutionModel, uint32_t id, std::string name,
                       std::initializer_list<uint32_t> interface);
 
-    spv::ExecutionModel executionModel;
-    uint32_t targetId;
-    std::string targetName;
-    std::initializer_list<uint32_t> interfaces;
+    const spv::ExecutionModel executionModel;
+    const uint32_t targetId;
+    const std::string targetName;
+    const std::initializer_list<uint32_t> interfaces;
   };
 
   /// \brief The struct representing a debug name.
@@ -185,27 +199,56 @@ private:
     inline DebugName(uint32_t id, llvm::Optional<uint32_t> index,
                      std::string targetName);
 
-    uint32_t targetId;
-    llvm::Optional<uint32_t> memberIndex;
-    std::string name;
+    const uint32_t targetId;
+    const llvm::Optional<uint32_t> memberIndex;
+    const std::string name;
+  };
+
+  /// \brief The struct representing a deocoration and its target <result-id>.
+  struct DecorationIdPair {
+    inline DecorationIdPair(const Decoration &decor, uint32_t id);
+
+    const Decoration &decoration;
+    const uint32_t targetId;
+  };
+
+  /// \brief The struct representing a type and its <result-id>.
+  struct TypeIdPair {
+    inline TypeIdPair(const Type &ty, uint32_t id);
+
+    const Type &type;
+    const uint32_t resultId;
+  };
+
+  /// \brief The struct representing a constant and its type.
+  struct Constant {
+    inline Constant(const Type &ty, Instruction &&value);
+    const Type &type;
+    Instruction constant;
   };
 
   Header header; ///< SPIR-V module header.
   std::vector<spv::Capability> capabilities;
   std::vector<std::string> extensions;
-  std::vector<std::pair<uint32_t, std::string>> extInstSets;
+  std::vector<ExtInstSet> extInstSets;
+  // addressing and memory model must exist for a valid SPIR-V module.
+  // We make them optional here just to provide extra flexibility of
+  // the representation.
   llvm::Optional<spv::AddressingModel> addressingModel;
   llvm::Optional<spv::MemoryModel> memoryModel;
   std::vector<EntryPoint> entryPoints;
-  // XXX: Right now the following are basically vectors of Instructions.
-  // They will be turned into vectors of more full-fledged classes gradually
-  // as we implement more features.
   std::vector<Instruction> executionModes;
-  // TODO: support other debug instructions
+  // TODO: source code debug information
   std::vector<DebugName> debugNames;
-  std::vector<Instruction> decorations;
-  std::vector<Instruction> typesValues;
-  std::vector<Function> functions;
+  std::vector<DecorationIdPair> decorations;
+  // Note that types and constants are interdependent; Types like arrays have
+  // <result-id>s for constants in their definition, and constants all have
+  // their corresponding types. We store types and constants separately, but
+  // they should be handled together.
+  llvm::MapVector<const Type *, uint32_t> types;
+  std::vector<Constant> constants;
+  // TODO: global variables
+  std::vector<std::unique_ptr<Function>> functions;
 };
 
 BasicBlock::BasicBlock() : labelId(0) {}
@@ -241,7 +284,7 @@ void Function::addParameter(uint32_t rType, uint32_t rId) {
   parameters.emplace_back(rType, rId);
 }
 
-void Function::addBasicBlock(BasicBlock &&block) {
+void Function::addBasicBlock(std::unique_ptr<BasicBlock> block) {
   blocks.push_back(std::move(block));
 }
 
@@ -279,16 +322,23 @@ void SPIRVModule::addDebugName(uint32_t targetId,
                                std::string name) {
   debugNames.emplace_back(targetId, memberIndex, std::move(name));
 }
-void SPIRVModule::addDecoration(Instruction &&decoration) {
-  decorations.push_back(std::move(decoration));
+void SPIRVModule::addDecoration(const Decoration &decoration,
+                                uint32_t targetId) {
+  decorations.emplace_back(decoration, targetId);
 }
-void SPIRVModule::addType(Instruction &&type) {
-  typesValues.push_back(std::move(type));
+void SPIRVModule::addType(const Type *type, uint32_t resultId) {
+  types.insert(std::make_pair(type, resultId));
 }
-void SPIRVModule::addFunction(Function &&f) {
+void SPIRVModule::addConstant(const Type &type, Instruction &&constant) {
+  constants.emplace_back(type, std::move(constant));
+};
+void SPIRVModule::addFunction(std::unique_ptr<Function> f) {
   functions.push_back(std::move(f));
 }
 
+SPIRVModule::ExtInstSet::ExtInstSet(uint32_t id, std::string name)
+    : resultId(id), setName(name) {}
+
 SPIRVModule::EntryPoint::EntryPoint(spv::ExecutionModel em, uint32_t id,
                                     std::string name,
                                     std::initializer_list<uint32_t> interface)
@@ -299,6 +349,16 @@ SPIRVModule::DebugName::DebugName(uint32_t id, llvm::Optional<uint32_t> index,
                                   std::string targetName)
     : targetId(id), memberIndex(index), name(std::move(targetName)) {}
 
+SPIRVModule::DecorationIdPair::DecorationIdPair(const Decoration &decor,
+                                                uint32_t id)
+    : decoration(decor), targetId(id) {}
+
+SPIRVModule::TypeIdPair::TypeIdPair(const Type &ty, uint32_t id)
+    : type(ty), resultId(id) {}
+
+SPIRVModule::Constant::Constant(const Type &ty, Instruction &&value)
+    : type(ty), constant(std::move(value)) {}
+
 } // end namespace spirv
 } // end namespace clang
 

+ 4 - 1
tools/clang/include/clang/SPIRV/Type.h

@@ -100,7 +100,7 @@ public:
                                 spv::StorageClass storage_class, uint32_t type,
                                 DecorationSet decs = {});
   static const Type *getFunction(SPIRVContext &ctx, uint32_t return_type,
-                                 std::initializer_list<uint32_t> params,
+                                 const std::vector<uint32_t> &params,
                                  DecorationSet decs = {});
   static const Type *getEvent(SPIRVContext &ctx, DecorationSet decs = {});
   static const Type *getDeviceEvent(SPIRVContext &ctx, DecorationSet decs = {});
@@ -116,6 +116,9 @@ public:
            decorations == other.decorations;
   }
 
+  // \brief Construct the SPIR-V words for this type with the given <result-id>.
+  std::vector<uint32_t> withResultId(uint32_t resultId) const;
+
 private:
   /// \brief Private constructor.
   Type(spv::Op op, std::vector<uint32_t> arg = {},

+ 21 - 0
tools/clang/lib/SPIRV/Decoration.cpp

@@ -281,5 +281,26 @@ Decoration::getSecondaryViewportRelativeNV(SPIRVContext &context,
   return getUniqueDecoration(context, d);
 }
 
+std::vector<uint32_t> Decoration::withTargetId(uint32_t targetId) const {
+  std::vector<uint32_t> words;
+
+  // TODO: we are essentially duplicate the work InstBuilder is responsible for.
+  // Should figure out a way to unify them.
+  words.reserve(3 + args.size() + (memberIndex.hasValue() ? 1 : 0));
+  if (memberIndex.hasValue()) {
+    words.push_back(static_cast<uint32_t>(spv::Op::OpMemberDecorate));
+    words.push_back(targetId);
+    words.push_back(*memberIndex);
+  } else {
+    words.push_back(static_cast<uint32_t>(spv::Op::OpDecorate));
+    words.push_back(targetId);
+  }
+  words.push_back(static_cast<uint32_t>(id));
+  words.insert(words.end(), args.begin(), args.end());
+  words.front() |= static_cast<uint32_t>(words.size()) << 16;
+
+  return words;
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 85 - 14
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -23,25 +23,96 @@
 
 namespace clang {
 namespace {
-class SPIRVEmitter : public ASTConsumer,
-                     public RecursiveASTVisitor<SPIRVEmitter> {
+
+class SPIRVEmitter : public ASTConsumer {
 public:
-  explicit SPIRVEmitter(raw_ostream *Out)
-      : OutStream(*Out), TheContext(), Builder(&TheContext) {}
-
-  void HandleTranslationUnit(ASTContext &Context) override {
-    Builder.beginModule();
-    Builder.endModule();
-    std::vector<uint32_t> M = Builder.takeModule();
-    OutStream.write(reinterpret_cast<const char *>(M.data()), M.size() * 4);
+  explicit SPIRVEmitter(raw_ostream *out)
+      : outStream(*out), theContext(), theBuilder(&theContext) {}
+
+  void HandleTranslationUnit(ASTContext &context) override {
+    theBuilder.requireCapability(spv::Capability::Shader);
+
+    // Addressing and memory model are required in a valid SPIR-V module.
+    theBuilder.setAddressingModel(spv::AddressingModel::Logical);
+    theBuilder.setMemoryModel(spv::MemoryModel::GLSL450);
+
+    // Process all top level Decls.
+    for (auto *decl : context.getTranslationUnitDecl()->decls()) {
+      doDecl(decl);
+    }
+
+    // Output the constructed module.
+    std::vector<uint32_t> m = theBuilder.takeModule();
+    outStream.write(reinterpret_cast<const char *>(m.data()), m.size() * 4);
+  }
+
+  void doDecl(Decl *decl) {
+    if (auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
+      doFunctionDecl(funcDecl);
+    }
+    // TODO: provide diagnostics of unimplemented features instead of silently
+    // ignoring them here.
+  }
+
+  void doFunctionDecl(FunctionDecl *decl) {
+    const uint32_t funcType = translateFunctionType(decl);
+    const uint32_t retType = translateType(decl->getReturnType());
+
+    theBuilder.beginFunction(funcType, retType);
+    // TODO: handle function parameters
+    // TODO: handle function body
+    const uint32_t entryLabel = theBuilder.bbCreate();
+    theBuilder.bbReturn(entryLabel);
+    theBuilder.endFunction();
+  }
+
+  uint32_t translateFunctionType(FunctionDecl *decl) {
+    const uint32_t retType = translateType(decl->getReturnType());
+    std::vector<uint32_t> paramTypes;
+    for (auto *param : decl->params()) {
+      paramTypes.push_back(translateType(param->getType()));
+    }
+    return theBuilder.getFunctionType(retType, paramTypes);
+  }
+
+  uint32_t translateType(QualType type) {
+    // In AST, vector types are TypedefType of TemplateSpecializationType,
+    // which is nested deeply. So we do fast track check here.
+    const auto symbol = type.getAsString();
+    if (symbol == "float4") {
+      const uint32_t floatType = theBuilder.getFloatType();
+      return theBuilder.getVec4Type(floatType);
+    } else if (symbol == "float3") {
+      const uint32_t floatType = theBuilder.getFloatType();
+      return theBuilder.getVec3Type(floatType);
+    } else if (symbol == "float2") {
+      const uint32_t floatType = theBuilder.getFloatType();
+      return theBuilder.getVec2Type(floatType);
+    } else if (auto *builtinType = dyn_cast<BuiltinType>(type.getTypePtr())) {
+      switch (builtinType->getKind()) {
+      case BuiltinType::Void:
+        return theBuilder.getVoidType();
+      case BuiltinType::Float:
+        return theBuilder.getFloatType();
+      default:
+        // TODO: handle other primitive types
+        assert(false && "unhandled builtin type");
+        break;
+      }
+    } else {
+      // TODO: handle other types
+      assert(false && "unhandled clang type");
+    }
+    return 0;
   }
 
 private:
-  raw_ostream &OutStream;
-  spirv::SPIRVContext TheContext;
-  spirv::ModuleBuilder Builder;
+  raw_ostream &outStream;
+  spirv::SPIRVContext theContext;
+  spirv::ModuleBuilder theBuilder;
 };
-}
+
+} // namespace
 
 std::unique_ptr<ASTConsumer>
 EmitSPIRVAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {

+ 97 - 46
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -9,87 +9,138 @@
 
 #include "clang/SPIRV/ModuleBuilder.h"
 
+#include "spirv/1.0//spirv.hpp11"
 #include "clang/SPIRV/InstBuilder.h"
 #include "llvm/llvm_assert/assert.h"
-#include "spirv/1.0//spirv.hpp11"
 
 namespace clang {
 namespace spirv {
 
 ModuleBuilder::ModuleBuilder(SPIRVContext *C)
-    : theContext(*C), theModule(), theFunction(llvm::None),
-      theBasicBlock(llvm::None), instBuilder(nullptr) {
+    : theContext(*C), theModule(), theFunction(nullptr), insertPoint(nullptr),
+      instBuilder(nullptr) {
   instBuilder.setConsumer([this](std::vector<uint32_t> &&words) {
     this->constructSite = std::move(words);
   });
 }
 
-ModuleBuilder::Status ModuleBuilder::beginModule() {
-  if (!theModule.isEmpty() || theFunction.hasValue() ||
-      theBasicBlock.hasValue())
-    return Status::ErrNestedModule;
+uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType) {
+  if (theFunction) {
+    assert(false && "found nested function");
+    return 0;
+  }
 
-  return Status::Success;
-}
+  const uint32_t fId = theContext.takeNextId();
 
-ModuleBuilder::Status ModuleBuilder::endModule() {
-  theModule.setBound(theContext.getNextId());
-  return Status::Success;
+  theFunction = llvm::make_unique<Function>(
+      returnType, fId, spv::FunctionControlMask::MaskNone, funcType);
+
+  return fId;
 }
 
-ModuleBuilder::Status ModuleBuilder::beginFunction(uint32_t funcType,
-                                                   uint32_t returnType) {
-  if (theFunction.hasValue())
-    return Status::ErrNestedFunction;
+bool ModuleBuilder::endFunction() {
+  if (theFunction == nullptr) {
+    assert(false && "no active function");
+    return false;
+  }
 
-  theFunction = llvm::Optional<Function>(
-      Function(returnType, theContext.takeNextId(),
-               spv::FunctionControlMask::MaskNone, funcType));
+  // Move all basic blocks into the current function.
+  // TODO: we should adjust the order the basic blocks according to
+  // SPIR-V validation rules.
+  for (auto &bb : basicBlocks) {
+    theFunction->addBasicBlock(std::move(bb.second));
+  }
+  basicBlocks.clear();
 
-  return Status::Success;
+  theModule.addFunction(std::move(theFunction));
+  theFunction.reset(nullptr);
+
+  insertPoint = nullptr;
+
+  return true;
 }
 
-ModuleBuilder::Status ModuleBuilder::endFunction() {
-  if (theBasicBlock.hasValue())
-    return Status::ErrActiveBasicBlock;
-  if (!theFunction.hasValue())
-    return Status::ErrNoActiveFunction;
+uint32_t ModuleBuilder::bbCreate() {
+  if (theFunction == nullptr) {
+    assert(false && "found detached basic block");
+    return 0;
+  }
+
+  const uint32_t labelId = theContext.takeNextId();
+  basicBlocks[labelId] = llvm::make_unique<BasicBlock>(labelId);
+
+  return labelId;
+}
 
-  theModule.addFunction(std::move(theFunction.getValue()));
-  theFunction.reset();
+bool ModuleBuilder::bbReturn(uint32_t labelId) {
+  auto it = basicBlocks.find(labelId);
+  if (it == basicBlocks.end()) {
+    assert(false && "invalid <label-id>");
+    return false;
+  }
 
-  return Status::Success;
+  instBuilder.opReturn().x();
+  it->second->addInstruction(std::move(constructSite));
+  return true;
 }
 
-ModuleBuilder::Status ModuleBuilder::beginBasicBlock() {
-  if (theBasicBlock.hasValue())
-    return Status::ErrNestedBasicBlock;
-  if (!theFunction.hasValue())
-    return Status::ErrDetachedBasicBlock;
+bool ModuleBuilder::setInsertPoint(uint32_t labelId) {
+  auto it = basicBlocks.find(labelId);
+  if (it == basicBlocks.end()) {
+    assert(false && "invalid <label-id>");
+    return false;
+  }
+  insertPoint = it->second.get();
+  return true;
+}
 
-  theBasicBlock =
-      llvm::Optional<BasicBlock>(BasicBlock(theContext.takeNextId()));
+uint32_t ModuleBuilder::getVoidType() {
+  const Type *type = Type::getVoid(theContext);
+  const uint32_t typeId = theContext.getResultIdForType(type);
+  theModule.addType(type, typeId);
+  return typeId;
+}
 
-  return Status::Success;
+uint32_t ModuleBuilder::getFloatType() {
+  const Type *type = Type::getFloat32(theContext);
+  const uint32_t typeId = theContext.getResultIdForType(type);
+  theModule.addType(type, typeId);
+  return typeId;
 }
 
-ModuleBuilder::Status ModuleBuilder::endBasicBlockWithReturn() {
-  if (!theBasicBlock.hasValue())
-    return Status::ErrNoActiveBasicBlock;
+uint32_t ModuleBuilder::getVec2Type(uint32_t elemType) {
+  const Type *type = Type::getVec2(theContext, elemType);
+  const uint32_t typeId = theContext.getResultIdForType(type);
+  theModule.addType(type, typeId);
+  return typeId;
+}
 
-  instBuilder.opReturn().x();
-  theBasicBlock.getValue().addInstruction(std::move(constructSite));
+uint32_t ModuleBuilder::getVec3Type(uint32_t elemType) {
+  const Type *type = Type::getVec3(theContext, elemType);
+  const uint32_t typeId = theContext.getResultIdForType(type);
+  theModule.addType(type, typeId);
+  return typeId;
+}
 
-  return endBasicBlock();
+uint32_t ModuleBuilder::getVec4Type(uint32_t elemType) {
+  const Type *type = Type::getVec4(theContext, elemType);
+  const uint32_t typeId = theContext.getResultIdForType(type);
+  theModule.addType(type, typeId);
+  return typeId;
 }
 
-ModuleBuilder::Status ModuleBuilder::endBasicBlock() {
-  theFunction.getValue().addBasicBlock(std::move(theBasicBlock.getValue()));
-  theBasicBlock.reset();
-  return Status::Success;
+uint32_t
+ModuleBuilder::getFunctionType(uint32_t returnType,
+                               const std::vector<uint32_t> &paramTypes) {
+  const Type *type = Type::getFunction(theContext, returnType, paramTypes);
+  const uint32_t typeId = theContext.getResultIdForType(type);
+  theModule.addType(type, typeId);
+  return typeId;
 }
 
 std::vector<uint32_t> ModuleBuilder::takeModule() {
+  theModule.setBound(theContext.getNextId());
+
   std::vector<uint32_t> binary;
   auto ib = InstBuilder([&binary](std::vector<uint32_t> &&words) {
     binary.insert(binary.end(), words.begin(), words.end());

+ 16 - 9
tools/clang/lib/SPIRV/Structure.cpp

@@ -94,7 +94,7 @@ void Function::take(InstBuilder *builder) {
     builder->opFunctionParameter(param.first, param.second).x();
   }
   for (auto &block : blocks) {
-    block.take(builder);
+    block->take(builder);
   }
   builder->opFunctionEnd().x();
   clear();
@@ -142,7 +142,7 @@ void SPIRVModule::take(InstBuilder *builder) {
   }
 
   for (auto &inst : extInstSets) {
-    builder->opExtInstImport(inst.first, inst.second).x();
+    builder->opExtInstImport(inst.resultId, inst.setName).x();
   }
 
   if (addressingModel.hasValue() && memoryModel.hasValue()) {
@@ -163,24 +163,31 @@ void SPIRVModule::take(InstBuilder *builder) {
   for (auto &inst : debugNames) {
     if (inst.memberIndex.hasValue()) {
       builder
-          ->opMemberName(inst.targetId, inst.memberIndex.getValue(),
-                         std::move(inst.name))
+          ->opMemberName(inst.targetId, *inst.memberIndex, std::move(inst.name))
           .x();
     } else {
       builder->opName(inst.targetId, std::move(inst.name)).x();
     }
   }
 
-  for (auto &inst : decorations) {
-    consumer(std::move(inst));
+  for (const auto &d : decorations) {
+    consumer(d.decoration.withTargetId(d.targetId));
   }
 
-  for (auto &inst : typesValues) {
-    consumer(std::move(inst));
+  // TODO: handle the interdependency between types and constants
+
+  for (const auto &t : types) {
+    consumer(t.first->withResultId(t.second));
   }
 
+  for (auto &c : constants) {
+    consumer(std::move(c.constant));
+  }
+
+  // TODO: global variables
+
   for (uint32_t i = 0; i < functions.size(); ++i) {
-    functions[i].take(builder);
+    functions[i]->take(builder);
   }
 
   clear();

+ 15 - 1
tools/clang/lib/SPIRV/Type.cpp

@@ -143,7 +143,7 @@ const Type *Type::getPointer(SPIRVContext &context,
   return getUniqueType(context, t);
 }
 const Type *Type::getFunction(SPIRVContext &context, uint32_t return_type,
-                              std::initializer_list<uint32_t> params,
+                              const std::vector<uint32_t> &params,
                               DecorationSet d) {
   std::vector<uint32_t> args = {return_type};
   args.insert(args.end(), params.begin(), params.end());
@@ -201,5 +201,19 @@ bool Type::hasDecoration(const Decoration *d) const {
   return decorations.find(d) != decorations.end();
 }
 
+std::vector<uint32_t> Type::withResultId(uint32_t resultId) const {
+  std::vector<uint32_t> words;
+
+  // TODO: we are essentially duplicate the work InstBuilder is responsible for.
+  // Should figure out a way to unify them.
+  words.reserve(2 + args.size());
+  words.push_back(static_cast<uint32_t>(opcode));
+  words.push_back(resultId);
+  words.insert(words.end(), args.begin(), args.end());
+  words.front() |= static_cast<uint32_t>(words.size()) << 16;
+
+  return words;
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 0 - 14
tools/clang/test/CodeGenSPIRV/basic.hlsl2spv

@@ -1,14 +0,0 @@
-// Comments 1
-// Comments 2
-// Run: %dxc -T ps_6_0 -E main
-void main()
-{
-
-}
-
-// CHECK-WHOLE-SPIR-V:
-// ; SPIR-V
-// ; Version: 1.0
-// ; Generator: Google spiregg; 0
-// ; Bound: 1
-// ; Schema: 0

+ 26 - 0
tools/clang/test/CodeGenSPIRV/empty-void-main.hlsl2spv

@@ -0,0 +1,26 @@
+// Run: %dxc -T ps_6_0 -E main
+void main()
+{
+
+}
+
+
+// TODO:
+// OpEntryPoint Fragment %main "main"
+// OpExecutionMode %main OriginUpperLeft
+
+
+// CHECK-WHOLE-SPIR-V:
+// ; SPIR-V
+// ; Version: 1.0
+// ; Generator: Google spiregg; 0
+// ; Bound: 5
+// ; Schema: 0
+// OpCapability Shader
+// OpMemoryModel Logical GLSL450
+// %void = OpTypeVoid
+// %2 = OpTypeFunction %void
+// %3 = OpFunction %void None %2
+// %4 = OpLabel
+// OpReturn
+// OpFunctionEnd

+ 30 - 0
tools/clang/test/CodeGenSPIRV/passthru-ps.hlsl2spv

@@ -0,0 +1,30 @@
+// Run: %dxc -T ps_6_0 -E main
+float4 main(float4 input: COLOR): SV_TARGET
+{
+    return input;
+}
+
+
+// TODO:
+// OpEntryPoint Fragment %main "main"
+// OpExecutionMode %main OriginUpperLeft
+// Semantics
+// Function parameter
+// Function return value
+
+
+// CHECK-WHOLE-SPIR-V:
+// ; SPIR-V
+// ; Version: 1.0
+// ; Generator: Google spiregg; 0
+// ; Bound: 6
+// ; Schema: 0
+// OpCapability Shader
+// OpMemoryModel Logical GLSL450
+// %float = OpTypeFloat 32
+// %v4float = OpTypeVector %float 4
+// %3 = OpTypeFunction %v4float %v4float
+// %4 = OpFunction %v4float None %3
+// %5 = OpLabel
+// OpReturn
+// OpFunctionEnd

+ 10 - 7
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -7,16 +7,19 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <fstream>
-
 #include "WholeFileCheck.h"
-#include "gtest/gtest.h"
 
-TEST_F(WholeFileTest, BringUp) {
+TEST_F(WholeFileTest, EmptyVoidMain) {
   // Ideally all generated SPIR-V must be valid, but this currently fails with
   // this error message: "No OpEntryPoint instruction was found...".
   // TODO: change this test such that it does run validation.
-  bool success = runWholeFileTest("basic.hlsl2spv", /*generateHeader*/ true,
-                                  /*runValidation*/ false);
-  EXPECT_TRUE(success);
+  runWholeFileTest("empty-void-main.hlsl2spv",
+                   /*generateHeader*/ true,
+                   /*runValidation*/ false);
+}
+
+TEST_F(WholeFileTest, PassThruPixelShader) {
+  runWholeFileTest("passthru-ps.hlsl2spv",
+                   /*generateHeader*/ true,
+                   /*runValidation*/ false);
 }

+ 21 - 0
tools/clang/unittests/SPIRV/DecorationTest.cpp

@@ -7,6 +7,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "SPIRVTestUtils.h"
 #include "gmock/gmock.h"
 #include "clang/SPIRV/Decoration.h"
 #include "clang/SPIRV/SPIRVContext.h"
@@ -17,6 +18,7 @@ using namespace clang::spirv;
 
 namespace {
 using ::testing::ElementsAre;
+using ::testing::ContainerEq;
 
 TEST(Decoration, SameDecorationWoParameterShouldHaveSameAddress) {
   SPIRVContext ctx;
@@ -560,4 +562,23 @@ TEST(Decoration, ViewportRelativeNV) {
   EXPECT_TRUE(dec->getArgs().empty());
 }
 
+TEST(Decoration, BlockWithTargetId) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getBlock(ctx);
+  const auto result = d->withTargetId(1);
+  const auto expected = constructInst(
+      spv::Op::OpDecorate, {1, static_cast<uint32_t>(spv::Decoration::Block)});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+
+TEST(Decoration, RowMajorWithTargetId) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getRowMajor(ctx, 3);
+  const auto result = d->withTargetId(2);
+  const auto expected =
+      constructInst(spv::Op::OpMemberDecorate,
+                    {2, 3, static_cast<uint32_t>(spv::Decoration::RowMajor)});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+
 } // anonymous namespace

+ 14 - 94
tools/clang/unittests/SPIRV/ModuleBuilderTest.cpp

@@ -7,8 +7,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "clang/SPIRV/ModuleBuilder.h"
 #include "spirv/1.0/spirv.hpp11"
+#include "clang/SPIRV/ModuleBuilder.h"
 
 #include "SPIRVTestUtils.h"
 
@@ -19,36 +19,24 @@ using namespace clang::spirv;
 using ::testing::ContainerEq;
 using ::testing::ElementsAre;
 
-void expectBuildSuccess(ModuleBuilder::Status status) {
-  EXPECT_EQ(ModuleBuilder::Status::Success, status);
-}
-
-TEST(ModuleBuilder, BeginAndThenEndModuleCreatesHeader) {
+TEST(ModuleBuilder, TakeModuleDirectlyCreatesHeader) {
   SPIRVContext context;
   ModuleBuilder builder(&context);
 
-  expectBuildSuccess(builder.beginModule());
-  expectBuildSuccess(builder.endModule());
-  std::vector<uint32_t> spvModule = builder.takeModule();
-
-  // At the very least, running BeginModule() and EndModule() should
-  // create the SPIR-V Header. The header is exactly 5 words long.
-  EXPECT_EQ(spvModule.size(), 5u);
-  EXPECT_THAT(spvModule,
+  EXPECT_THAT(builder.takeModule(),
               ElementsAre(spv::MagicNumber, spv::Version, 14u << 16, 1u, 0u));
 }
 
-TEST(ModuleBuilder, BeginEndFunctionCreatesFunction) {
+TEST(ModuleBuilder, CreateFunction) {
   SPIRVContext context;
   ModuleBuilder builder(&context);
 
-  expectBuildSuccess(builder.beginModule());
   const auto rType = context.takeNextId();
   const auto fType = context.takeNextId();
   const auto fId = context.getNextId();
-  expectBuildSuccess(builder.beginFunction(fType, rType));
-  expectBuildSuccess(builder.endFunction());
-  expectBuildSuccess(builder.endModule());
+  const auto resultId = builder.beginFunction(fType, rType);
+  EXPECT_EQ(fId, resultId);
+  EXPECT_TRUE(builder.endFunction());
   const auto result = builder.takeModule();
 
   auto expected = getModuleHeader(context.getNextId());
@@ -58,20 +46,20 @@ TEST(ModuleBuilder, BeginEndFunctionCreatesFunction) {
   EXPECT_THAT(result, ContainerEq(expected));
 }
 
-TEST(ModuleBuilder, BeginEndBasicBlockCreatesBasicBlock) {
+TEST(ModuleBuilder, CreateBasicBlock) {
   SPIRVContext context;
   ModuleBuilder builder(&context);
 
-  expectBuildSuccess(builder.beginModule());
   const auto rType = context.takeNextId();
   const auto fType = context.takeNextId();
   const auto fId = context.getNextId();
-  expectBuildSuccess(builder.beginFunction(fType, rType));
+  EXPECT_NE(0, builder.beginFunction(fType, rType));
   const auto labelId = context.getNextId();
-  expectBuildSuccess(builder.beginBasicBlock());
-  expectBuildSuccess(builder.endBasicBlockWithReturn());
-  expectBuildSuccess(builder.endFunction());
-  expectBuildSuccess(builder.endModule());
+  const auto resultId = builder.bbCreate();
+  EXPECT_EQ(labelId, resultId);
+  EXPECT_TRUE(builder.bbReturn(resultId));
+  EXPECT_TRUE(builder.endFunction());
+
   const auto result = builder.takeModule();
 
   auto expected = getModuleHeader(context.getNextId());
@@ -84,72 +72,4 @@ TEST(ModuleBuilder, BeginEndBasicBlockCreatesBasicBlock) {
   EXPECT_THAT(result, ContainerEq(expected));
 }
 
-TEST(ModuleBuilder, NestedModuleResultsInError) {
-  SPIRVContext context;
-  ModuleBuilder builder(&context);
-
-  expectBuildSuccess(builder.beginModule());
-  expectBuildSuccess(builder.beginModule());
-  expectBuildSuccess(builder.beginFunction(1, 2));
-  EXPECT_EQ(ModuleBuilder::Status::ErrNestedModule, builder.beginModule());
-}
-
-TEST(ModuleBuilder, NestedFunctionResultsInError) {
-  SPIRVContext context;
-  ModuleBuilder builder(&context);
-
-  expectBuildSuccess(builder.beginModule());
-  expectBuildSuccess(builder.beginFunction(1, 2));
-  EXPECT_EQ(ModuleBuilder::Status::ErrNestedFunction,
-            builder.beginFunction(3, 4));
-}
-
-TEST(ModuleBuilder, NestedBasicBlockResultsInError) {
-  SPIRVContext context;
-  ModuleBuilder builder(&context);
-
-  expectBuildSuccess(builder.beginModule());
-  expectBuildSuccess(builder.beginFunction(1, 2));
-  expectBuildSuccess(builder.beginBasicBlock());
-  EXPECT_EQ(ModuleBuilder::Status::ErrNestedBasicBlock,
-            builder.beginBasicBlock());
-}
-
-TEST(ModuleBuilder, BasicBlockWoFunctionResultsInError) {
-  SPIRVContext context;
-  ModuleBuilder builder(&context);
-
-  expectBuildSuccess(builder.beginModule());
-  EXPECT_EQ(ModuleBuilder::Status::ErrDetachedBasicBlock,
-            builder.beginBasicBlock());
-}
-
-TEST(ModuleBuilder, EndFunctionWoBeginFunctionResultsInError) {
-  SPIRVContext context;
-  ModuleBuilder builder(&context);
-
-  expectBuildSuccess(builder.beginModule());
-  EXPECT_EQ(ModuleBuilder::Status::ErrNoActiveFunction, builder.endFunction());
-}
-
-TEST(ModuleBuilder, EndFunctionWActiveBasicBlockResultsInError) {
-  SPIRVContext context;
-  ModuleBuilder builder(&context);
-
-  expectBuildSuccess(builder.beginModule());
-  expectBuildSuccess(builder.beginFunction(1, 2));
-  expectBuildSuccess(builder.beginBasicBlock());
-  EXPECT_EQ(ModuleBuilder::Status::ErrActiveBasicBlock, builder.endFunction());
-}
-
-TEST(ModuleBuilder, ReturnWActiveBasicBlockResultsInError) {
-  SPIRVContext context;
-  ModuleBuilder builder(&context);
-
-  expectBuildSuccess(builder.beginModule());
-  expectBuildSuccess(builder.beginFunction(1, 2));
-  EXPECT_EQ(ModuleBuilder::Status::ErrNoActiveBasicBlock,
-            builder.endBasicBlockWithReturn());
-}
-
 } // anonymous namespace

+ 70 - 42
tools/clang/unittests/SPIRV/StructureTest.cpp

@@ -10,6 +10,7 @@
 #include "clang/SPIRV/Structure.h"
 
 #include "SPIRVTestUtils.h"
+#include "clang/SPIRV/SPIRVContext.h"
 
 namespace {
 
@@ -55,8 +56,8 @@ TEST(Structure, TakeFunctionHaveAllContents) {
   auto f = Function(1, 2, spv::FunctionControlMask::Inline, 3);
   f.addParameter(1, 42);
 
-  auto bb = BasicBlock(10);
-  bb.addInstruction(constructInst(spv::Op::OpReturn, {}));
+  auto bb = llvm::make_unique<BasicBlock>(10);
+  bb->addInstruction(constructInst(spv::Op::OpReturn, {}));
   f.addBasicBlock(std::move(bb));
 
   std::vector<uint32_t> result;
@@ -96,9 +97,11 @@ TEST(Structure, AfterClearModuleIsEmpty) {
 }
 
 TEST(Structure, TakeModuleHaveAllContents) {
+  SPIRVContext context;
   auto m = SPIRVModule();
-  std::vector<uint32_t> expected{spv::MagicNumber, spv::Version,
-                                 /* generator */ 14u << 16, /* bound */ 6, 0};
+
+  // Will fix up the bound later.
+  std::vector<uint32_t> expected = getModuleHeader(0);
 
   m.addCapability(spv::Capability::Shader);
   appendVector(&expected,
@@ -109,9 +112,11 @@ TEST(Structure, TakeModuleHaveAllContents) {
   const uint32_t extWord = 'e' | ('x' << 8) | ('t' << 16);
   appendVector(&expected, constructInst(spv::Op::OpExtension, {extWord}));
 
-  m.addExtInstSet(5, "gl");
+  const uint32_t extInstSetId = context.takeNextId();
+  m.addExtInstSet(extInstSetId, "gl");
   const uint32_t glWord = 'g' | ('l' << 8);
-  appendVector(&expected, constructInst(spv::Op::OpExtInstImport, {5, glWord}));
+  appendVector(&expected,
+               constructInst(spv::Op::OpExtInstImport, {extInstSetId, glWord}));
 
   m.setAddressingModel(spv::AddressingModel::Logical);
   m.setMemoryModel(spv::MemoryModel::GLSL450);
@@ -121,57 +126,80 @@ TEST(Structure, TakeModuleHaveAllContents) {
                     {static_cast<uint32_t>(spv::AddressingModel::Logical),
                      static_cast<uint32_t>(spv::MemoryModel::GLSL450)}));
 
-  m.addEntryPoint(spv::ExecutionModel::Fragment, 2, "main", {42});
+  const uint32_t entryPointId = context.takeNextId();
+  m.addEntryPoint(spv::ExecutionModel::Fragment, entryPointId, "main", {42});
   const uint32_t mainWord = 'm' | ('a' << 8) | ('i' << 16) | ('n' << 24);
   appendVector(
       &expected,
       constructInst(spv::Op::OpEntryPoint,
-                    {static_cast<uint32_t>(spv::ExecutionModel::Fragment), 2,
-                     mainWord, /* addtional null in name */ 0, 42}));
+                    {static_cast<uint32_t>(spv::ExecutionModel::Fragment),
+                     entryPointId, mainWord, /* addtional null in name */ 0,
+                     42}));
 
   m.addExecutionMode(constructInst(
       spv::Op::OpExecutionMode,
-      {2, static_cast<uint32_t>(spv::ExecutionMode::OriginUpperLeft)}));
-  appendVector(&expected,
-               constructInst(spv::Op::OpExecutionMode,
-                             {2, static_cast<uint32_t>(
-                                     spv::ExecutionMode::OriginUpperLeft)}));
-
-  // TODO: other debug instructions
-
-  m.addDebugName(2, llvm::None, "main");
-  appendVector(&expected,
-               constructInst(spv::Op::OpName,
-                             {2, mainWord, /* additional null in name */ 0}));
-
-  m.addDecoration(constructInst(
-      spv::Op::OpDecorate,
-      {2, static_cast<uint32_t>(spv::Decoration::RelaxedPrecision)}));
-  appendVector(&expected,
-               constructInst(spv::Op::OpDecorate,
-                             {2, static_cast<uint32_t>(
-                                     spv::Decoration::RelaxedPrecision)}));
+      {entryPointId,
+       static_cast<uint32_t>(spv::ExecutionMode::OriginUpperLeft)}));
+  appendVector(
+      &expected,
+      constructInst(spv::Op::OpExecutionMode,
+                    {entryPointId, static_cast<uint32_t>(
+                                       spv::ExecutionMode::OriginUpperLeft)}));
 
-  m.addType(constructInst(spv::Op::OpTypeVoid, {1}));
-  appendVector(&expected, constructInst(spv::Op::OpTypeVoid, {1}));
+  // TODO: source code debug information
 
-  m.addType(constructInst(spv::Op::OpTypeFunction, {3, 1, 1}));
-  appendVector(&expected, constructInst(spv::Op::OpTypeFunction, {3, 1, 1}));
+  m.addDebugName(entryPointId, llvm::None, "main");
+  appendVector(&expected, constructInst(spv::Op::OpName,
+                                        {entryPointId, mainWord,
+                                         /* additional null in name */ 0}));
 
-  // TODO: constant
-  // TODO: variable
+  m.addDecoration(*Decoration::getRelaxedPrecision(context), entryPointId);
 
-  auto f = Function(1, 2, spv::FunctionControlMask::MaskNone, 3);
-  auto bb = BasicBlock(4);
-  bb.addInstruction(constructInst(spv::Op::OpReturn, {}));
-  f.addBasicBlock(std::move(bb));
+  appendVector(&expected,
+               constructInst(
+                   spv::Op::OpDecorate,
+                   {entryPointId,
+                    static_cast<uint32_t>(spv::Decoration::RelaxedPrecision)}));
+
+  const auto *voidType = Type::getVoid(context);
+  const uint32_t voidId = context.getResultIdForType(voidType);
+  m.addType(voidType, voidId);
+  appendVector(&expected, constructInst(spv::Op::OpTypeVoid, {voidId}));
+
+  const auto *funcType = Type::getFunction(context, voidId, {voidId});
+  const uint32_t funcTypeId = context.getResultIdForType(funcType);
+  m.addType(funcType, funcTypeId);
+  appendVector(&expected, constructInst(spv::Op::OpTypeFunction,
+                                        {funcTypeId, voidId, voidId}));
+
+  const auto *i32Type = Type::getInt32(context);
+  const uint32_t i32Id = context.getResultIdForType(i32Type);
+  m.addType(i32Type, i32Id);
+  appendVector(&expected, constructInst(spv::Op::OpTypeInt, {i32Id, 32, 1}));
+
+  const uint32_t constantId = context.takeNextId();
+  m.addConstant(*i32Type,
+                constructInst(spv::Op::OpConstant, {i32Id, constantId, 42}));
+  appendVector(&expected,
+               constructInst(spv::Op::OpConstant, {i32Id, constantId, 42}));
+  // TODO: global variable
+
+  const uint32_t funcId = context.takeNextId();
+  auto f = llvm::make_unique<Function>(
+      voidId, funcId, spv::FunctionControlMask::MaskNone, funcTypeId);
+  const uint32_t bbId = context.takeNextId();
+  auto bb = llvm::make_unique<BasicBlock>(bbId);
+  bb->addInstruction(constructInst(spv::Op::OpReturn, {}));
+  f->addBasicBlock(std::move(bb));
   m.addFunction(std::move(f));
-  appendVector(&expected, constructInst(spv::Op::OpFunction, {1, 2, 0, 3}));
-  appendVector(&expected, constructInst(spv::Op::OpLabel, {4}));
+  appendVector(&expected, constructInst(spv::Op::OpFunction,
+                                        {voidId, funcId, 0, funcTypeId}));
+  appendVector(&expected, constructInst(spv::Op::OpLabel, {bbId}));
   appendVector(&expected, constructInst(spv::Op::OpReturn, {}));
   appendVector(&expected, constructInst(spv::Op::OpFunctionEnd, {}));
 
-  m.setBound(6);
+  m.setBound(context.getNextId());
+  expected[3] = context.getNextId();
 
   std::vector<uint32_t> result;
   auto ib = constructInstBuilder(result);

+ 17 - 0
tools/clang/unittests/SPIRV/TypeTest.cpp

@@ -7,6 +7,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "SPIRVTestUtils.h"
 #include "gmock/gmock.h"
 #include "clang/SPIRV/SPIRVContext.h"
 #include "clang/SPIRV/String.h"
@@ -16,6 +17,7 @@
 using namespace clang::spirv;
 
 namespace {
+using ::testing::ContainerEq;
 using ::testing::ElementsAre;
 
 TEST(Type, SameTypeWoParameterShouldHaveSameAddress) {
@@ -544,4 +546,19 @@ TEST(Type, DecoratedForwardPointer) {
   EXPECT_THAT(t->getDecorations(), ElementsAre(d));
 }
 
+TEST(Type, BoolWithResultId) {
+  SPIRVContext ctx;
+  const Type *t = Type::getBool(ctx);
+  const auto words = t->withResultId(1);
+  EXPECT_THAT(words, ContainerEq(constructInst(spv::Op::OpTypeBool, {1})));
+}
+
+TEST(Type, IntWithResultId) {
+  SPIRVContext ctx;
+  const Type *t = Type::getInt32(ctx);
+  const auto words = t->withResultId(42);
+  EXPECT_THAT(words,
+              ContainerEq(constructInst(spv::Op::OpTypeInt, {42, 32, 1})));
+}
+
 } // anonymous namespace

+ 6 - 14
tools/clang/unittests/SPIRV/WholeFileCheck.cpp

@@ -180,10 +180,6 @@ void WholeFileTest::convertIDxcBlobToUint32(const CComPtr<IDxcBlob> &blob) {
   memcpy(generatedBinary.data(), binaryStr.data(), binaryStr.size());
 }
 
-bool WholeFileTest::compareExpectedSpirvAndGeneratedSpirv() {
-  return generatedSpirvAsm == expectedSpirvAsm;
-}
-
 std::string
 WholeFileTest::getAbsPathOfInputDataFile(const std::string &filename) {
   std::string path = clang::spirv::testOptions::inputDataDir;
@@ -202,28 +198,24 @@ WholeFileTest::getAbsPathOfInputDataFile(const std::string &filename) {
   return path;
 }
 
-bool WholeFileTest::runWholeFileTest(std::string filename, bool generateHeader,
+void WholeFileTest::runWholeFileTest(std::string filename, bool generateHeader,
                                      bool runSpirvValidation) {
   inputFilePath = getAbsPathOfInputDataFile(filename);
 
-  bool success = true;
-
   // Parse the input file.
-  success = success && parseInputFile();
+  ASSERT_TRUE(parseInputFile());
 
   // Feed the HLSL source into the Compiler.
-  success = success && runCompilerWithSpirvGeneration();
+  ASSERT_TRUE(runCompilerWithSpirvGeneration());
 
   // Disassemble the generated SPIR-V binary.
-  success = success && disassembleSpirvBinary(generateHeader);
+  ASSERT_TRUE(disassembleSpirvBinary(generateHeader));
 
   // Run SPIR-V validation if requested.
   if (runSpirvValidation) {
-    success = success && validateSpirvBinary();
+    ASSERT_TRUE(validateSpirvBinary());
   }
 
   // Compare the expected and the generted SPIR-V code.
-  success = success && compareExpectedSpirvAndGeneratedSpirv();
-
-  return success;
+  EXPECT_EQ(expectedSpirvAsm, generatedSpirvAsm);
 }

+ 1 - 5
tools/clang/unittests/SPIRV/WholeFileCheck.h

@@ -55,7 +55,7 @@ public:
   /// It is also important that all generated SPIR-V code is valid. Users of
   /// WholeFileTest may choose not to run the SPIR-V Validator (for cases where
   /// a certain feature has not been added to the Validator yet).
-  bool runWholeFileTest(std::string path, bool generateHeader = false,
+  void runWholeFileTest(std::string path, bool generateHeader = false,
                         bool runSpirvValidation = true);
 
 private:
@@ -77,10 +77,6 @@ private:
   /// Returns true if validation is successful; false otherwise.
   bool validateSpirvBinary();
 
-  /// \brief Compares the expected and the generated SPIR-V code.
-  /// Returns true if they match, and false otherwise.
-  bool compareExpectedSpirvAndGeneratedSpirv();
-
   /// \brief Parses the Target Profile and Entry Point from the Run command
   bool processRunCommandArgs(const std::string &runCommandLine);