浏览代码

[spirv] Improve SPIR-V structures and tests for them (#477)

* Changed Instruction to a class
* Created SimpleInstBuilder for building instructions in tests
* Various cosmetic changes
Lei Zhang 8 年之前
父节点
当前提交
6503469808

+ 85 - 25
tools/clang/include/clang/SPIRV/Structure.h

@@ -35,30 +35,66 @@
 namespace clang {
 namespace spirv {
 
-// TODO: do some statistics and switch to use SmallVector here if helps.
+// === Instruction definition ===
+
 /// \brief The class representing a SPIR-V instruction.
-using Instruction = std::vector<uint32_t>;
+class Instruction {
+public:
+  /// Constructs an instruction from the given underlying SPIR-V binary words.
+  inline Instruction(std::vector<uint32_t> &&);
+
+  // Copy constructor/assignment
+  Instruction(const Instruction &) = default;
+  Instruction &operator=(const Instruction &) = default;
+
+  // Move constructor/assignment
+  Instruction(Instruction &&) = default;
+  Instruction &operator=(Instruction &&) = default;
+
+  /// Returns true if this instruction is empty, which contains no underlying
+  /// SPIR-V binary words.
+  inline bool isEmpty() const;
+
+  /// Returns the opcode for this instruction. Returns spv::Op::Max if this
+  /// instruction is empty.
+  spv::Op getOpcode() const;
+
+  /// Returns the underlying SPIR-v binary words for this instruction.
+  /// This instruction will be in an empty state after this call.
+  inline std::vector<uint32_t> take();
+
+  /// Returns true if this instruction is a termination instruction.
+  ///
+  /// See "2.2.4. Control Flow" in the SPIR-V spec for the defintion of
+  /// termination instructions.
+  bool isTerminator() const;
+
+private:
+  // TODO: do some statistics and switch to use SmallVector here if helps.
+  std::vector<uint32_t> words; ///< Underlying SPIR-V words
+};
+
+// === Basic block definition ===
 
 /// \brief The class representing a SPIR-V basic block.
 class BasicBlock {
 public:
-  /// \brief Default constructs an empty basic block.
-  inline BasicBlock();
-  /// \brief Constructs a basic block with the given label id.
+  /// \brief Constructs a basic block with the given <label-id>.
   inline explicit BasicBlock(uint32_t labelId);
 
-  // Disable copy constructor/assignment until we find they are truly useful.
+  // Disable copy constructor/assignment
   BasicBlock(const BasicBlock &) = delete;
   BasicBlock &operator=(const BasicBlock &) = delete;
-  // Allow move constructor/assignment since they are efficient.
+
+  // Move constructor/assignment
   BasicBlock(BasicBlock &&that);
   BasicBlock &operator=(BasicBlock &&that);
 
   /// \brief Returns true if this basic block is empty, which has no <label-id>
   /// assigned and no instructions.
   inline bool isEmpty() const;
-  /// \brief Clears all instructions in this basic block and turns this basic
-  /// block into an empty basic block.
+  /// \brief Clears everything in this basic block and turns it into an
+  /// empty basic block.
   inline void clear();
 
   /// \brief Serializes this basic block and feeds it to the comsumer in the
@@ -80,20 +116,20 @@ private:
   std::deque<Instruction> instructions;
 };
 
+// === Function definition ===
+
 /// \brief The class representing a SPIR-V function.
 class Function {
 public:
-  /// \brief Default constructs an empty SPIR-V function.
-  inline Function();
   /// \brief Constructs a SPIR-V function with the given parameters.
   inline Function(uint32_t resultType, uint32_t resultId,
                   spv::FunctionControlMask control, uint32_t functionType);
 
-  // Disable copy constructor/assignment until we find they are truly useful.
+  // Disable copy constructor/assignment
   Function(const Function &) = delete;
   Function &operator=(const Function &) = delete;
 
-  // Allow move constructor/assignment since they are efficient.
+  // Move constructor/assignment
   Function(Function &&that);
   Function &operator=(Function &&that);
 
@@ -131,6 +167,8 @@ private:
   std::vector<std::unique_ptr<BasicBlock>> blocks;
 };
 
+// === Module components defintion ====
+
 /// \brief The struct representing a SPIR-V module header.
 struct Header {
   /// \brief Default constructs a SPIR-V module header with id bound 0.
@@ -191,17 +229,19 @@ struct TypeIdPair {
   const uint32_t resultId;
 };
 
+// === Module defintion ====
+
 /// \brief The class representing a SPIR-V module.
 class SPIRVModule {
 public:
   /// \brief Default constructs an empty SPIR-V module.
   inline SPIRVModule();
 
-  // Disable copy constructor/assignment until we find they are truly useful.
+  // Disable copy constructor/assignment
   SPIRVModule(const SPIRVModule &) = delete;
   SPIRVModule &operator=(const SPIRVModule &) = delete;
 
-  // Allow move constructor/assignment since they are efficient.
+  // Move constructor/assignment
   SPIRVModule(SPIRVModule &&that) = default;
   SPIRVModule &operator=(SPIRVModule &&that) = default;
 
@@ -248,14 +288,14 @@ private:
   /// If found, (a) defines the constant by passing it to the consumer in the
   /// given InstBuilder. (b) Removes the constant from the list of constants
   /// in this object.
-  void takeConstantForArrayType(const Type *arrType, InstBuilder *ib);
+  void takeConstantForArrayType(const Type &arrType, InstBuilder *ib);
 
 private:
   Header header; ///< SPIR-V module header.
   std::vector<spv::Capability> capabilities;
   std::vector<std::string> extensions;
   std::vector<ExtInstSet> extInstSets;
-  // addressing and memory model must exist for a valid SPIR-V module.
+  // Addressing and memory model must exist for a valid SPIR-V module.
   // We make them optional here just to provide extra flexibility of
   // the representation.
   llvm::Optional<spv::AddressingModel> addressingModel;
@@ -276,16 +316,22 @@ private:
   std::vector<std::unique_ptr<Function>> functions;
 };
 
-BasicBlock::BasicBlock() : labelId(0) {}
+// === Instruction inline implementations ===
+
+Instruction::Instruction(std::vector<uint32_t> &&data)
+    : words(std::move(data)) {}
+
+bool Instruction::isEmpty() const { return words.empty(); }
+
+std::vector<uint32_t> Instruction::take() { return std::move(words); }
+
+// === Basic block inline implementations ===
+
 BasicBlock::BasicBlock(uint32_t id) : labelId(id) {}
 
 bool BasicBlock::isEmpty() const {
   return labelId == 0 && instructions.empty();
 }
-void BasicBlock::clear() {
-  labelId = 0;
-  instructions.clear();
-}
 
 void BasicBlock::appendInstruction(Instruction &&inst) {
   instructions.push_back(std::move(inst));
@@ -295,9 +341,7 @@ void BasicBlock::prependInstruction(Instruction &&inst) {
   instructions.push_front(std::move(inst));
 }
 
-Function::Function()
-    : resultType(0), resultId(0),
-      funcControl(spv::FunctionControlMask::MaskNone), funcType(0) {}
+// === Function inline implementations ===
 
 Function::Function(uint32_t rType, uint32_t rId,
                    spv::FunctionControlMask control, uint32_t fType)
@@ -317,6 +361,8 @@ void Function::addBasicBlock(std::unique_ptr<BasicBlock> block) {
   blocks.push_back(std::move(block));
 }
 
+// === Module components inline implementations ===
+
 ExtInstSet::ExtInstSet(uint32_t id, std::string name)
     : resultId(id), setName(name) {}
 
@@ -334,6 +380,8 @@ DecorationIdPair::DecorationIdPair(const Decoration &decor, uint32_t id)
 
 TypeIdPair::TypeIdPair(const Type &ty, uint32_t id) : type(ty), resultId(id) {}
 
+// === Module inline implementations ===
+
 SPIRVModule::SPIRVModule()
     : addressingModel(llvm::None), memoryModel(llvm::None) {}
 
@@ -342,45 +390,57 @@ void SPIRVModule::setBound(uint32_t newBound) { header.bound = newBound; }
 void SPIRVModule::addCapability(spv::Capability cap) {
   capabilities.push_back(cap);
 }
+
 void SPIRVModule::addExtension(std::string ext) {
   extensions.push_back(std::move(ext));
 }
+
 void SPIRVModule::addExtInstSet(uint32_t setId, std::string extInstSet) {
   extInstSets.emplace_back(setId, extInstSet);
 }
+
 void SPIRVModule::setAddressingModel(spv::AddressingModel am) {
   addressingModel = llvm::Optional<spv::AddressingModel>(am);
 }
+
 void SPIRVModule::setMemoryModel(spv::MemoryModel mm) {
   memoryModel = llvm::Optional<spv::MemoryModel>(mm);
 }
+
 void SPIRVModule::addEntryPoint(spv::ExecutionModel em, uint32_t targetId,
                                 std::string name,
                                 llvm::ArrayRef<uint32_t> interfaces) {
   entryPoints.emplace_back(em, targetId, std::move(name), interfaces);
 }
+
 void SPIRVModule::addExecutionMode(Instruction &&execMode) {
   executionModes.push_back(std::move(execMode));
 }
+
 void SPIRVModule::addDebugName(uint32_t targetId, llvm::StringRef name,
                                llvm::Optional<uint32_t> memberIndex) {
   if (!name.empty()) {
     debugNames.emplace_back(targetId, name, memberIndex);
   }
 }
+
 void SPIRVModule::addDecoration(const Decoration &decoration,
                                 uint32_t targetId) {
   decorations.emplace_back(decoration, targetId);
 }
+
 void SPIRVModule::addType(const Type *type, uint32_t resultId) {
   types.insert(std::make_pair(type, resultId));
 }
+
 void SPIRVModule::addConstant(const Constant *constant, uint32_t resultId) {
   constants.insert(std::make_pair(constant, resultId));
 };
+
 void SPIRVModule::addVariable(Instruction &&var) {
   variables.push_back(std::move(var));
 }
+
 void SPIRVModule::addFunction(std::unique_ptr<Function> f) {
   functions.push_back(std::move(f));
 }

+ 84 - 55
tools/clang/lib/SPIRV/Structure.cpp

@@ -15,22 +15,34 @@ namespace spirv {
 namespace {
 constexpr uint32_t kGeneratorNumber = 14;
 constexpr uint32_t kToolVersion = 0;
+} // namespace
 
-bool isTerminator(spv::Op opcode) {
-  switch (opcode) {
-  case spv::Op::OpKill:
-  case spv::Op::OpUnreachable:
+// === Instruction implementations ===
+
+spv::Op Instruction::getOpcode() const {
+  if (!isEmpty()) {
+    return static_cast<spv::Op>(words.front() & spv::OpCodeMask);
+  }
+
+  return spv::Op::Max;
+}
+
+bool Instruction::isTerminator() const {
+  switch (getOpcode()) {
   case spv::Op::OpBranch:
   case spv::Op::OpBranchConditional:
-  case spv::Op::OpSwitch:
   case spv::Op::OpReturn:
   case spv::Op::OpReturnValue:
+  case spv::Op::OpSwitch:
+  case spv::Op::OpKill:
+  case spv::Op::OpUnreachable:
     return true;
   default:
     return false;
   }
 }
-} // namespace
+
+// === Basic block implementations ===
 
 BasicBlock::BasicBlock(BasicBlock &&that)
     : labelId(that.labelId), instructions(std::move(that.instructions)) {
@@ -46,29 +58,35 @@ BasicBlock &BasicBlock::operator=(BasicBlock &&that) {
   return *this;
 }
 
+void BasicBlock::clear() {
+  labelId = 0;
+  instructions.clear();
+}
+
 void BasicBlock::take(InstBuilder *builder) {
   // Make sure we have a terminator instruction at the end.
-  // TODO: This is a little bit ugly. It suggests that we should put the opcode
-  // in the Instruction struct. But fine for now.
   assert(isTerminated() && "found basic block without terminator");
+
   builder->opLabel(labelId).x();
+
   for (auto &inst : instructions) {
-    builder->getConsumer()(std::move(inst));
+    builder->getConsumer()(inst.take());
   }
+
   clear();
 }
 
 bool BasicBlock::isTerminated() const {
-  return !instructions.empty() &&
-         isTerminator(
-             // Take the last 16 bits and convert it into opcode
-             static_cast<spv::Op>(instructions.back().front() & 0xffff));
+  return !instructions.empty() && instructions.back().isTerminator();
 }
 
+// === Function implementations ===
+
 Function::Function(Function &&that)
     : resultType(that.resultType), resultId(that.resultId),
       funcControl(that.funcControl), funcType(that.funcType),
-      parameters(std::move(that.parameters)), blocks(std::move(that.blocks)) {
+      parameters(std::move(that.parameters)),
+      variables(std::move(that.variables)), blocks(std::move(that.blocks)) {
   that.clear();
 }
 
@@ -96,14 +114,6 @@ void Function::clear() {
   blocks.clear();
 }
 
-void Function::addVariable(uint32_t varType, uint32_t varId,
-                           llvm::Optional<uint32_t> init) {
-  variables.emplace_back(
-      InstBuilder(nullptr)
-          .opVariable(varType, varId, spv::StorageClass::Function, init)
-          .take());
-}
-
 void Function::take(InstBuilder *builder) {
   builder->opFunction(resultType, resultId, funcControl, funcType).x();
 
@@ -130,6 +140,16 @@ void Function::take(InstBuilder *builder) {
   clear();
 }
 
+void Function::addVariable(uint32_t varType, uint32_t varId,
+                           llvm::Optional<uint32_t> init) {
+  variables.emplace_back(
+      InstBuilder(nullptr)
+          .opVariable(varType, varId, spv::StorageClass::Function, init)
+          .take());
+}
+
+// === Module components implementations ===
+
 Header::Header()
     : magicNumber(spv::MagicNumber), version(spv::Version),
       generator((kGeneratorNumber << 16) | kToolVersion), bound(0),
@@ -145,16 +165,20 @@ void Header::collect(const WordConsumer &consumer) {
   consumer(std::move(words));
 }
 
+// === Module implementations ===
+
 bool SPIRVModule::isEmpty() const {
   return header.bound == 0 && capabilities.empty() && extensions.empty() &&
          extInstSets.empty() && !addressingModel.hasValue() &&
          !memoryModel.hasValue() && entryPoints.empty() &&
          executionModes.empty() && debugNames.empty() && decorations.empty() &&
+         types.empty() && constants.empty() && variables.empty() &&
          functions.empty();
 }
 
 void SPIRVModule::clear() {
   header.bound = 0;
+
   capabilities.clear();
   extensions.clear();
   extInstSets.clear();
@@ -164,43 +188,18 @@ void SPIRVModule::clear() {
   executionModes.clear();
   debugNames.clear();
   decorations.clear();
-  functions.clear();
-}
-
-void SPIRVModule::takeIntegerTypes(InstBuilder *ib) {
-  const auto &consumer = ib->getConsumer();
-  // If it finds any integer type, feeds it into the consumer, and removes it
-  // from the types collection.
-  types.remove_if([&consumer](std::pair<const Type *, uint32_t> &item) {
-    const bool isInteger = item.first->isIntegerType();
-    if (isInteger)
-      consumer(item.first->withResultId(item.second));
-    return isInteger;
-  });
-}
-
-void SPIRVModule::takeConstantForArrayType(const Type *arrType,
-                                           InstBuilder *ib) {
-  assert(arrType->isArrayType() &&
-         "takeConstantForArrayType was called with a non-array type.");
-  const auto &consumer = ib->getConsumer();
-  const uint32_t arrayLengthResultId = arrType->getArgs().back();
+  types.clear();
+  constants.clear();
+  variables.clear();
 
-  // If it finds the constant, feeds it into the consumer, and removes it
-  // from the constants collection.
-  constants.remove_if([&consumer, arrayLengthResultId](
-      std::pair<const Constant *, uint32_t> &item) {
-    const bool isArrayLengthConstant = (item.second == arrayLengthResultId);
-    if (isArrayLengthConstant)
-      consumer(item.first->withResultId(item.second));
-    return isArrayLengthConstant;
-  });
+  functions.clear();
 }
 
 void SPIRVModule::take(InstBuilder *builder) {
   const auto &consumer = builder->getConsumer();
 
   // Order matters here.
+
   header.collect(consumer);
 
   for (auto &cap : capabilities) {
@@ -227,7 +226,7 @@ void SPIRVModule::take(InstBuilder *builder) {
   }
 
   for (auto &inst : executionModes) {
-    consumer(std::move(inst));
+    consumer(inst.take());
   }
 
   for (auto &inst : debugNames) {
@@ -257,7 +256,7 @@ void SPIRVModule::take(InstBuilder *builder) {
     // If we have an array type, we must first define the integer constant that
     // defines its length.
     if (t.first->isArrayType()) {
-      takeConstantForArrayType(t.first, builder);
+      takeConstantForArrayType(*t.first, builder);
     }
 
     consumer(t.first->withResultId(t.second));
@@ -268,7 +267,7 @@ void SPIRVModule::take(InstBuilder *builder) {
   }
 
   for (auto &v : variables) {
-    consumer(std::move(v));
+    consumer(v.take());
   }
 
   for (uint32_t i = 0; i < functions.size(); ++i) {
@@ -278,5 +277,35 @@ void SPIRVModule::take(InstBuilder *builder) {
   clear();
 }
 
+void SPIRVModule::takeIntegerTypes(InstBuilder *ib) {
+  const auto &consumer = ib->getConsumer();
+  // If it finds any integer type, feeds it into the consumer, and removes it
+  // from the types collection.
+  types.remove_if([&consumer](std::pair<const Type *, uint32_t> &item) {
+    const bool isInteger = item.first->isIntegerType();
+    if (isInteger)
+      consumer(item.first->withResultId(item.second));
+    return isInteger;
+  });
+}
+
+void SPIRVModule::takeConstantForArrayType(const Type &arrType,
+                                           InstBuilder *ib) {
+  assert(arrType.isArrayType());
+
+  const auto &consumer = ib->getConsumer();
+  const uint32_t arrayLengthResultId = arrType.getArgs().back();
+
+  // If it finds the constant, feeds it into the consumer, and removes it
+  // from the constants collection.
+  constants.remove_if([&consumer, arrayLengthResultId](
+      std::pair<const Constant *, uint32_t> &item) {
+    const bool isArrayLengthConstant = (item.second == arrayLengthResultId);
+    if (isArrayLengthConstant)
+      consumer(item.first->withResultId(item.second));
+    return isArrayLengthConstant;
+  });
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 5 - 3
tools/clang/unittests/SPIRV/CMakeLists.txt

@@ -18,7 +18,7 @@ add_clang_unittest(clang-spirv-tests
   TestMain.cpp
   StringTest.cpp
   TypeTest.cpp
-  WholeFileCheck.cpp
+  WholeFileTestFixture.cpp
   )
 
 target_link_libraries(clang-spirv-tests
@@ -29,6 +29,8 @@ target_link_libraries(clang-spirv-tests
   SPIRV-Tools
   )
 
-target_include_directories(clang-spirv-tests PRIVATE ${SPIRV_TOOLS_INCLUDE_DIR} ${DXC_EFFCEE_DIR})
+target_include_directories(clang-spirv-tests
+  PRIVATE ${SPIRV_TOOLS_INCLUDE_DIR} ${DXC_EFFCEE_DIR})
 
-set_output_directory(clang-spirv-tests ${LLVM_RUNTIME_OUTPUT_INTDIR} ${LLVM_LIBRARY_OUTPUT_INTDIR})
+set_output_directory(clang-spirv-tests
+  ${LLVM_RUNTIME_OUTPUT_INTDIR} ${LLVM_LIBRARY_OUTPUT_INTDIR})

+ 1 - 1
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -8,7 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "FileTestFixture.h"
-#include "WholeFileCheck.h"
+#include "WholeFileTestFixture.h"
 
 namespace {
 using clang::spirv::FileTest;

+ 8 - 8
tools/clang/unittests/SPIRV/InstBuilderTest.cpp

@@ -169,20 +169,20 @@ TEST(InstBuilder, InstWStringParams) {
   expectBuildSuccess(ib.opString(5, "main").x());
   expectBuildSuccess(ib.opString(6, "mainf").x());
 
-  std::vector<uint32_t> expected;
+  SimpleInstBuilder sib;
   uint32_t strWord = 0;
-  appendVector(&expected, constructInst(spv::Op::OpString, {1, strWord}));
+  sib.inst(spv::Op::OpString, {1, strWord});
   strWord = 'm';
-  appendVector(&expected, constructInst(spv::Op::OpString, {2, strWord}));
+  sib.inst(spv::Op::OpString, {2, strWord});
   strWord |= 'a' << 8;
-  appendVector(&expected, constructInst(spv::Op::OpString, {3, strWord}));
+  sib.inst(spv::Op::OpString, {3, strWord});
   strWord |= 'i' << 16;
-  appendVector(&expected, constructInst(spv::Op::OpString, {4, strWord}));
+  sib.inst(spv::Op::OpString, {4, strWord});
   strWord |= 'n' << 24;
-  appendVector(&expected, constructInst(spv::Op::OpString, {5, strWord, 0}));
-  appendVector(&expected, constructInst(spv::Op::OpString, {6, strWord, 'f'}));
+  sib.inst(spv::Op::OpString, {5, strWord, 0});
+  sib.inst(spv::Op::OpString, {6, strWord, 'f'});
 
-  EXPECT_THAT(result, ContainerEq(expected));
+  EXPECT_THAT(result, ContainerEq(sib.get()));
 }
 // TOOD: Add tests for providing more parameters than needed
 

+ 10 - 12
tools/clang/unittests/SPIRV/ModuleBuilderTest.cpp

@@ -39,11 +39,10 @@ TEST(ModuleBuilder, CreateFunction) {
   EXPECT_TRUE(builder.endFunction());
   const auto result = builder.takeModule();
 
-  auto expected = getModuleHeader(context.getNextId());
-  appendVector(&expected,
-               constructInst(spv::Op::OpFunction, {rType, fId, 0, fType}));
-  appendVector(&expected, constructInst(spv::Op::OpFunctionEnd, {}));
-  EXPECT_THAT(result, ContainerEq(expected));
+  SimpleInstBuilder sib(context.getNextId());
+  sib.inst(spv::Op::OpFunction, {rType, fId, 0, fType});
+  sib.inst(spv::Op::OpFunctionEnd, {});
+  EXPECT_THAT(result, ContainerEq(sib.get()));
 }
 
 TEST(ModuleBuilder, CreateBasicBlock) {
@@ -63,14 +62,13 @@ TEST(ModuleBuilder, CreateBasicBlock) {
 
   const auto result = builder.takeModule();
 
-  auto expected = getModuleHeader(context.getNextId());
-  appendVector(&expected,
-               constructInst(spv::Op::OpFunction, {rType, fId, 0, fType}));
-  appendVector(&expected, constructInst(spv::Op::OpLabel, {labelId}));
-  appendVector(&expected, constructInst(spv::Op::OpReturn, {}));
-  appendVector(&expected, constructInst(spv::Op::OpFunctionEnd, {}));
+  SimpleInstBuilder sib(context.getNextId());
+  sib.inst(spv::Op::OpFunction, {rType, fId, 0, fType});
+  sib.inst(spv::Op::OpLabel, {labelId});
+  sib.inst(spv::Op::OpReturn, {});
+  sib.inst(spv::Op::OpFunctionEnd, {});
 
-  EXPECT_THAT(result, ContainerEq(expected));
+  EXPECT_THAT(result, ContainerEq(sib.get()));
 }
 
 } // anonymous namespace

+ 4 - 4
tools/clang/unittests/SPIRV/SPIRVContextTest.cpp

@@ -17,7 +17,7 @@ using namespace clang::spirv;
 
 namespace {
 
-TEST(ValidateSPIRVContext, ValidateGetNextId) {
+TEST(SPIRVContext, GetNextId) {
   SPIRVContext context;
   // Check that the first ID is 1.
   EXPECT_EQ(context.getNextId(), 1u);
@@ -25,14 +25,14 @@ TEST(ValidateSPIRVContext, ValidateGetNextId) {
   EXPECT_EQ(context.getNextId(), 1u);
 }
 
-TEST(ValidateSPIRVContext, ValidateTakeNextId) {
+TEST(SPIRVContext, TakeNextId) {
   SPIRVContext context;
   EXPECT_EQ(context.takeNextId(), 1u);
   EXPECT_EQ(context.takeNextId(), 2u);
   EXPECT_EQ(context.getNextId(), 3u);
 }
 
-TEST(ValidateSPIRVContext, ValidateUniqueIdForUniqueNonAggregateType) {
+TEST(SPIRVContext, UniqueIdForUniqueNonAggregateType) {
   SPIRVContext ctx;
   const Type *intt = Type::getInt32(ctx);
   uint32_t intt_id = ctx.getResultIdForType(intt);
@@ -41,7 +41,7 @@ TEST(ValidateSPIRVContext, ValidateUniqueIdForUniqueNonAggregateType) {
   EXPECT_EQ(intt_id, intt_id_again);
 }
 
-TEST(ValidateSPIRVContext, ValidateUniqueIdForUniqueAggregateType) {
+TEST(SPIRVContext, UniqueIdForUniqueAggregateType) {
   SPIRVContext ctx;
   // In this test we construct a struct which includes an integer member and
   // a boolean member.

+ 33 - 14
tools/clang/unittests/SPIRV/SPIRVTestUtils.h

@@ -18,8 +18,9 @@
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
-#include "clang/SPIRV/InstBuilder.h"
 #include "spirv/1.0/spirv.hpp11"
+#include "clang/SPIRV/InstBuilder.h"
+#include "llvm/ADT/ArrayRef.h"
 
 namespace clang {
 namespace spirv {
@@ -32,33 +33,51 @@ inline InstBuilder constructInstBuilder(std::vector<uint32_t> &binary) {
   });
 }
 
+/// Returns the words in SPIR-V module header with the given id bound.
+inline std::vector<uint32_t> getModuleHeader(uint32_t bound) {
+  return {spv::MagicNumber, spv::Version, 14u << 16, bound, 0};
+}
+
 /// Creates a SPIR-V instruction.
-inline std::vector<uint32_t>
-constructInst(spv::Op opcode, std::initializer_list<uint32_t> params) {
+inline std::vector<uint32_t> constructInst(spv::Op opcode,
+                                           llvm::ArrayRef<uint32_t> params) {
   std::vector<uint32_t> words;
+
   words.push_back(static_cast<uint32_t>(opcode));
   for (auto w : params) {
     words.push_back(w);
   }
   words.front() |= static_cast<uint32_t>(words.size()) << 16;
+
   return words;
 }
 
+/// A simple instruction builder for testing purpose.
+class SimpleInstBuilder {
+public:
+  /// Constructs a simple instruction builder with no module header.
+  SimpleInstBuilder() {}
+
+  /// Constructs a simple instruction builder with module header having the
+  /// given id bound.
+  explicit SimpleInstBuilder(uint32_t bound) : words(getModuleHeader(bound)) {}
+
+  /// Adds an instruction.
+  void inst(spv::Op opcode, llvm::ArrayRef<uint32_t> params) {
+    auto inst = constructInst(opcode, params);
+    words.insert(words.end(), inst.begin(), inst.end());
+  }
+
+  const std::vector<uint32_t> &get() const { return words; }
+
+private:
+  std::vector<uint32_t> words;
+};
+
 /// Expects the given status is success.
 inline void expectBuildSuccess(InstBuilder::Status status) {
   EXPECT_EQ(InstBuilder::Status::Success, status);
 }
 
-/// Appends the part vector to the end of the all vector.
-inline void appendVector(std::vector<uint32_t> *all,
-                         const std::vector<uint32_t> &part) {
-  all->insert(all->end(), part.begin(), part.end());
-}
-
-/// Returns the words in SPIR-V module header with the given id bound.
-inline std::vector<uint32_t> getModuleHeader(uint32_t bound) {
-  return {spv::MagicNumber, spv::Version, 14u << 16, bound, 0};
-}
-
 } // end namespace spirv
 } // end namespace clang

+ 93 - 79
tools/clang/unittests/SPIRV/StructureTest.cpp

@@ -18,9 +18,37 @@ using namespace clang::spirv;
 
 using ::testing::ContainerEq;
 
-TEST(Structure, DefaultConstructedBasicBlockIsEmpty) {
-  auto bb = BasicBlock();
-  EXPECT_TRUE(bb.isEmpty());
+TEST(Structure, InstructionHasCorrectOpcode) {
+  Instruction inst(constructInst(spv::Op::OpIAdd, {1, 2, 3, 4}));
+
+  ASSERT_TRUE(!inst.isEmpty());
+  EXPECT_EQ(inst.getOpcode(), spv::Op::OpIAdd);
+}
+
+TEST(Structure, InstructionGetOriginalContents) {
+  Instruction inst(constructInst(spv::Op::OpIAdd, {1, 2, 3, 4}));
+
+  EXPECT_THAT(inst.take(),
+              ContainerEq(constructInst(spv::Op::OpIAdd, {1, 2, 3, 4})));
+}
+
+TEST(Structure, InstructionIsTerminator) {
+  for (auto opcode :
+       {spv::Op::OpKill, spv::Op::OpUnreachable, spv::Op::OpBranch,
+        spv::Op::OpBranchConditional, spv::Op::OpSwitch, spv::Op::OpReturn,
+        spv::Op::OpReturnValue}) {
+    Instruction inst(constructInst(opcode, {/* wrong params here */ 1}));
+
+    EXPECT_TRUE(inst.isTerminator());
+  }
+}
+
+TEST(Structure, InstructionIsNotTerminator) {
+  for (auto opcode : {spv::Op::OpNop, spv::Op::OpAccessChain, spv::Op::OpAll}) {
+    Instruction inst(constructInst(opcode, {/* wrong params here */ 1}));
+
+    EXPECT_FALSE(inst.isTerminator());
+  }
 }
 
 TEST(Structure, TakeBasicBlockHaveAllContents) {
@@ -31,11 +59,11 @@ TEST(Structure, TakeBasicBlockHaveAllContents) {
   bb.appendInstruction(constructInst(spv::Op::OpReturn, {}));
   bb.take(&ib);
 
-  std::vector<uint32_t> expected;
-  appendVector(&expected, constructInst(spv::Op::OpLabel, {42}));
-  appendVector(&expected, constructInst(spv::Op::OpReturn, {}));
+  SimpleInstBuilder sib;
+  sib.inst(spv::Op::OpLabel, {42});
+  sib.inst(spv::Op::OpReturn, {});
 
-  EXPECT_THAT(result, ContainerEq(expected));
+  EXPECT_THAT(result, ContainerEq(sib.get()));
   EXPECT_TRUE(bb.isEmpty());
 }
 
@@ -47,11 +75,6 @@ TEST(Structure, AfterClearBasicBlockIsEmpty) {
   EXPECT_TRUE(bb.isEmpty());
 }
 
-TEST(Structure, DefaultConstructedFunctionIsEmpty) {
-  auto f = Function();
-  EXPECT_TRUE(f.isEmpty());
-}
-
 TEST(Structure, TakeFunctionHaveAllContents) {
   auto f = Function(1, 2, spv::FunctionControlMask::Inline, 3);
   f.addParameter(1, 42);
@@ -64,14 +87,14 @@ TEST(Structure, TakeFunctionHaveAllContents) {
   auto ib = constructInstBuilder(result);
   f.take(&ib);
 
-  std::vector<uint32_t> expected;
-  appendVector(&expected, constructInst(spv::Op::OpFunction, {1, 2, 1, 3}));
-  appendVector(&expected, constructInst(spv::Op::OpFunctionParameter, {1, 42}));
-  appendVector(&expected, constructInst(spv::Op::OpLabel, {10}));
-  appendVector(&expected, constructInst(spv::Op::OpReturn, {}));
-  appendVector(&expected, constructInst(spv::Op::OpFunctionEnd, {}));
+  SimpleInstBuilder sib;
+  sib.inst(spv::Op::OpFunction, {1, 2, 1, 3});
+  sib.inst(spv::Op::OpFunctionParameter, {1, 42});
+  sib.inst(spv::Op::OpLabel, {10});
+  sib.inst(spv::Op::OpReturn, {});
+  sib.inst(spv::Op::OpFunctionEnd, {});
 
-  EXPECT_THAT(result, ContainerEq(expected));
+  EXPECT_THAT(result, ContainerEq(sib.get()));
   EXPECT_TRUE(f.isEmpty());
 }
 
@@ -101,88 +124,79 @@ TEST(Structure, TakeModuleHaveAllContents) {
   auto m = SPIRVModule();
 
   // Will fix up the bound later.
-  std::vector<uint32_t> expected = getModuleHeader(0);
+  SimpleInstBuilder sib(0);
 
   m.addCapability(spv::Capability::Shader);
-  appendVector(&expected,
-               constructInst(spv::Op::OpCapability,
-                             {static_cast<uint32_t>(spv::Capability::Shader)}));
+  sib.inst(spv::Op::OpCapability,
+           {static_cast<uint32_t>(spv::Capability::Shader)});
 
   m.addExtension("ext");
   const uint32_t extWord = 'e' | ('x' << 8) | ('t' << 16);
-  appendVector(&expected, constructInst(spv::Op::OpExtension, {extWord}));
+  sib.inst(spv::Op::OpExtension, {extWord});
 
   const uint32_t extInstSetId = context.takeNextId();
   m.addExtInstSet(extInstSetId, "gl");
   const uint32_t glWord = 'g' | ('l' << 8);
-  appendVector(&expected,
-               constructInst(spv::Op::OpExtInstImport, {extInstSetId, glWord}));
+  sib.inst(spv::Op::OpExtInstImport, {extInstSetId, glWord});
 
   m.setAddressingModel(spv::AddressingModel::Logical);
   m.setMemoryModel(spv::MemoryModel::GLSL450);
-  appendVector(
-      &expected,
-      constructInst(spv::Op::OpMemoryModel,
-                    {static_cast<uint32_t>(spv::AddressingModel::Logical),
-                     static_cast<uint32_t>(spv::MemoryModel::GLSL450)}));
+  sib.inst(spv::Op::OpMemoryModel,
+           {static_cast<uint32_t>(spv::AddressingModel::Logical),
+            static_cast<uint32_t>(spv::MemoryModel::GLSL450)});
 
   const uint32_t entryPointId = context.takeNextId();
   m.addEntryPoint(spv::ExecutionModel::Fragment, entryPointId, "main", {42});
   const uint32_t mainWord = 'm' | ('a' << 8) | ('i' << 16) | ('n' << 24);
-  appendVector(
-      &expected,
-      constructInst(spv::Op::OpEntryPoint,
-                    {static_cast<uint32_t>(spv::ExecutionModel::Fragment),
-                     entryPointId, mainWord, /* addtional null in name */ 0,
-                     42}));
+  sib.inst(spv::Op::OpEntryPoint,
+           {static_cast<uint32_t>(spv::ExecutionModel::Fragment), entryPointId,
+            mainWord, /* addtional null in name */ 0, 42});
 
   m.addExecutionMode(constructInst(
       spv::Op::OpExecutionMode,
       {entryPointId,
        static_cast<uint32_t>(spv::ExecutionMode::OriginUpperLeft)}));
-  appendVector(
-      &expected,
-      constructInst(spv::Op::OpExecutionMode,
-                    {entryPointId, static_cast<uint32_t>(
-                                       spv::ExecutionMode::OriginUpperLeft)}));
+  sib.inst(spv::Op::OpExecutionMode,
+           {entryPointId,
+            static_cast<uint32_t>(spv::ExecutionMode::OriginUpperLeft)});
 
   // TODO: source code debug information
 
   m.addDebugName(entryPointId, "main");
-  appendVector(&expected, constructInst(spv::Op::OpName,
-                                        {entryPointId, mainWord,
-                                         /* additional null in name */ 0}));
+  sib.inst(spv::Op::OpName,
+           {entryPointId, mainWord, /* additional null in name */ 0});
 
   m.addDecoration(*Decoration::getRelaxedPrecision(context), entryPointId);
-
-  appendVector(&expected,
-               constructInst(
-                   spv::Op::OpDecorate,
-                   {entryPointId,
-                    static_cast<uint32_t>(spv::Decoration::RelaxedPrecision)}));
+  sib.inst(
+      spv::Op::OpDecorate,
+      {entryPointId, static_cast<uint32_t>(spv::Decoration::RelaxedPrecision)});
 
   const auto *i32Type = Type::getInt32(context);
   const uint32_t i32Id = context.getResultIdForType(i32Type);
   m.addType(i32Type, i32Id);
-  appendVector(&expected, constructInst(spv::Op::OpTypeInt, {i32Id, 32, 1}));
+  sib.inst(spv::Op::OpTypeInt, {i32Id, 32, 1});
 
   const auto *voidType = Type::getVoid(context);
   const uint32_t voidId = context.getResultIdForType(voidType);
   m.addType(voidType, voidId);
-  appendVector(&expected, constructInst(spv::Op::OpTypeVoid, {voidId}));
+  sib.inst(spv::Op::OpTypeVoid, {voidId});
 
   const auto *funcType = Type::getFunction(context, voidId, {voidId});
   const uint32_t funcTypeId = context.getResultIdForType(funcType);
   m.addType(funcType, funcTypeId);
-  appendVector(&expected, constructInst(spv::Op::OpTypeFunction,
-                                        {funcTypeId, voidId, voidId}));
+  sib.inst(spv::Op::OpTypeFunction, {funcTypeId, voidId, voidId});
 
   const auto *i32Const = Constant::getInt32(context, i32Id, 42);
   const uint32_t constantId = context.takeNextId();
   m.addConstant(i32Const, constantId);
-  appendVector(&expected,
-               constructInst(spv::Op::OpConstant, {i32Id, constantId, 42}));
-  // TODO: global variable
+  sib.inst(spv::Op::OpConstant, {i32Id, constantId, 42});
+
+  const uint32_t varId = context.takeNextId();
+  m.addVariable(constructInst(
+      spv::Op::OpVariable,
+      {i32Id, varId, static_cast<uint32_t>(spv::StorageClass::Input)}));
+  sib.inst(spv::Op::OpVariable,
+           {i32Id, varId, static_cast<uint32_t>(spv::StorageClass::Input)});
 
   const uint32_t funcId = context.takeNextId();
   auto f = llvm::make_unique<Function>(
@@ -192,13 +206,14 @@ TEST(Structure, TakeModuleHaveAllContents) {
   bb->appendInstruction(constructInst(spv::Op::OpReturn, {}));
   f->addBasicBlock(std::move(bb));
   m.addFunction(std::move(f));
-  appendVector(&expected, constructInst(spv::Op::OpFunction,
-                                        {voidId, funcId, 0, funcTypeId}));
-  appendVector(&expected, constructInst(spv::Op::OpLabel, {bbId}));
-  appendVector(&expected, constructInst(spv::Op::OpReturn, {}));
-  appendVector(&expected, constructInst(spv::Op::OpFunctionEnd, {}));
+  sib.inst(spv::Op::OpFunction, {voidId, funcId, 0, funcTypeId});
+  sib.inst(spv::Op::OpLabel, {bbId});
+  sib.inst(spv::Op::OpReturn, {});
+  sib.inst(spv::Op::OpFunctionEnd, {});
 
   m.setBound(context.getNextId());
+
+  std::vector<uint32_t> expected = sib.get();
   expected[3] = context.getNextId();
 
   std::vector<uint32_t> result;
@@ -212,7 +227,9 @@ TEST(Structure, TakeModuleHaveAllContents) {
 TEST(Structure, TakeModuleWithArrayAndConstantDependency) {
   SPIRVContext context;
   auto m = SPIRVModule();
-  std::vector<uint32_t> expected = getModuleHeader(0);
+
+  // Will fix up the id bound later.
+  SimpleInstBuilder sib(0);
 
   // Add void type
   const auto *voidType = Type::getVoid(context);
@@ -256,22 +273,19 @@ TEST(Structure, TakeModuleWithArrayAndConstantDependency) {
   m.setBound(context.getNextId());
 
   // Decorations
-  appendVector(
-      &expected,
-      constructInst(spv::Op::OpDecorate,
-                    {secondArrId,
-                     static_cast<uint32_t>(spv::Decoration::ArrayStride), 4}));
+  sib.inst(
+      spv::Op::OpDecorate,
+      {secondArrId, static_cast<uint32_t>(spv::Decoration::ArrayStride), 4});
   // Now the expected order: int64, int32, void, float, constant(8), array
-  appendVector(&expected, constructInst(spv::Op::OpTypeInt, {i64Id, 64, 1}));
-  appendVector(&expected, constructInst(spv::Op::OpTypeInt, {i32Id, 32, 1}));
-  appendVector(&expected, constructInst(spv::Op::OpTypeVoid, {voidId}));
-  appendVector(&expected, constructInst(spv::Op::OpTypeFloat, {f32Id, 32}));
-  appendVector(&expected,
-               constructInst(spv::Op::OpConstant, {i32Id, constantId, 8}));
-  appendVector(&expected,
-               constructInst(spv::Op::OpTypeArray, {arrId, i32Id, constantId}));
-  appendVector(&expected, constructInst(spv::Op::OpTypeArray,
-                                        {secondArrId, i32Id, constantId}));
+  sib.inst(spv::Op::OpTypeInt, {i64Id, 64, 1});
+  sib.inst(spv::Op::OpTypeInt, {i32Id, 32, 1});
+  sib.inst(spv::Op::OpTypeVoid, {voidId});
+  sib.inst(spv::Op::OpTypeFloat, {f32Id, 32});
+  sib.inst(spv::Op::OpConstant, {i32Id, constantId, 8});
+  sib.inst(spv::Op::OpTypeArray, {arrId, i32Id, constantId});
+  sib.inst(spv::Op::OpTypeArray, {secondArrId, i32Id, constantId});
+
+  std::vector<uint32_t> expected = sib.get();
   expected[3] = context.getNextId();
 
   std::vector<uint32_t> result;

+ 1 - 1
tools/clang/unittests/SPIRV/TestMain.cpp

@@ -1,4 +1,4 @@
-//===--- utils/unittest/UnitTestMain/TestMain.cpp - unittest driver -------===//
+//===--- utils/unittest/SPIRV/TestMain.cpp - unittest driver --------------===//
 //
 //                     The LLVM Compiler Infrastructure
 //

+ 2 - 2
tools/clang/unittests/SPIRV/WholeFileCheck.cpp → tools/clang/unittests/SPIRV/WholeFileTestFixture.cpp

@@ -1,4 +1,4 @@
-//===- unittests/SPIRV/WholeFileCheck.cpp - WholeFileCheck Implementation -===//
+//===- unittests/SPIRV/WholeFileTestFixture.cpp - WholeFileTest impl ------===//
 //
 //                     The LLVM Compiler Infrastructure
 //
@@ -10,7 +10,7 @@
 #include <fstream>
 
 #include "FileTestUtils.h"
-#include "WholeFileCheck.h"
+#include "WholeFileTestFixture.h"
 
 namespace clang {
 namespace spirv {

+ 3 - 3
tools/clang/unittests/SPIRV/WholeFileCheck.h → tools/clang/unittests/SPIRV/WholeFileTestFixture.h

@@ -1,4 +1,4 @@
-//===- unittests/SPIRV/WholeFileCheck.h ---- WholeFileCheck Test Fixture --===//
+//===- unittests/SPIRV/WholeFileTestFixture.h - Whole file test Fixture ---===//
 //
 //                     The LLVM Compiler Infrastructure
 //
@@ -7,8 +7,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef LLVM_CLANG_UNITTESTS_SPIRV_WHOLEFILECHECK_H
-#define LLVM_CLANG_UNITTESTS_SPIRV_WHOLEFILECHECK_H
+#ifndef LLVM_CLANG_UNITTESTS_SPIRV_WHOLEFILETESTFIXTURE_H
+#define LLVM_CLANG_UNITTESTS_SPIRV_WHOLEFILETESTFIXTURE_H
 
 #include "llvm/ADT/StringRef.h"
 #include "gtest/gtest.h"