Browse Source

[spirv] Fully translate pass-through vertex shader! (#445)

* Handled structs as entry point output type
* Handled local variable of struct type
* Replaced the term of "interface variable" with "stage variable"
Lei Zhang 8 years ago
parent
commit
2562dab026

+ 9 - 1
tools/clang/include/clang/SPIRV/InstBuilder.h

@@ -80,8 +80,12 @@ public:
   void setConsumer(WordConsumer);
   void setConsumer(WordConsumer);
   const WordConsumer &getConsumer() const;
   const WordConsumer &getConsumer() const;
 
 
-  /// \brief Finalizes the building.
+  /// \brief Finalizes the building and feeds the generated SPIR-V words
+  /// to the consumer.
   Status x();
   Status x();
+  /// \brief Finalizes the building and returns the generated SPIR-V words.
+  /// Returns an empty vector if errors happened during the construction.
+  std::vector<uint32_t> take();
   /// \brief Clears the current instruction under building.
   /// \brief Clears the current instruction under building.
   void clear();
   void clear();
 
 
@@ -787,6 +791,10 @@ public:
                                            uint32_t result_id, uint32_t value,
                                            uint32_t result_id, uint32_t value,
                                            uint32_t index);
                                            uint32_t index);
 
 
+  // Methods for building constants.
+  InstBuilder &opConstant(uint32_t result_type, uint32_t result_id,
+                          uint32_t value);
+
   // Methods for supplying additional parameters.
   // Methods for supplying additional parameters.
   InstBuilder &fPFastMathMode(spv::FPFastMathModeMask);
   InstBuilder &fPFastMathMode(spv::FPFastMathModeMask);
   InstBuilder &fPRoundingMode(spv::FPRoundingMode);
   InstBuilder &fPRoundingMode(spv::FPRoundingMode);

+ 28 - 6
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -45,9 +45,16 @@ public:
                          std::string name = "");
                          std::string name = "");
 
 
   /// \brief Registers a function parameter of the given type onto the current
   /// \brief Registers a function parameter of the given type onto the current
-  /// function under construction and returns its <result-id>.
+  /// function and returns its <result-id>.
   uint32_t addFnParameter(uint32_t type);
   uint32_t addFnParameter(uint32_t type);
 
 
+  /// \brief Creates a local variable of the given value type in the current
+  /// function and returns its <result-id>.
+  ///
+  /// The corresponding pointer type of the given value type will be constructed
+  /// for the variable itself.
+  uint32_t addFnVariable(uint32_t valueType);
+
   /// \brief Ends building of the current function. Returns true of success,
   /// \brief Ends building of the current function. Returns true of success,
   /// false on failure. All basic blocks constructed from the beginning or
   /// false on failure. All basic blocks constructed from the beginning or
   /// after ending the previous function will be collected into this function.
   /// after ending the previous function will be collected into this function.
@@ -75,6 +82,12 @@ public:
   /// address.
   /// address.
   void createStore(uint32_t address, uint32_t value);
   void createStore(uint32_t address, uint32_t value);
 
 
+  /// \brief Creates an access chain instruction to retrieve the element from
+  /// the given base by walking through the given indexes. Returns the
+  /// <result-id> for the pointer to the element.
+  uint32_t createAccessChain(uint32_t resultType, uint32_t base,
+                             llvm::ArrayRef<uint32_t> indexes);
+
   /// \brief Creates a return instruction.
   /// \brief Creates a return instruction.
   void createReturn();
   void createReturn();
   /// \brief Creates a return value instruction.
   /// \brief Creates a return value instruction.
@@ -101,8 +114,13 @@ public:
   ///
   ///
   /// The corresponding pointer type of the given type will be constructed in
   /// The corresponding pointer type of the given type will be constructed in
   /// this method for the variable itself.
   /// this method for the variable itself.
-  uint32_t addStageIOVariable(uint32_t type, spv::StorageClass storageClass,
-                              llvm::Optional<uint32_t> initializer);
+  uint32_t addStageIOVariable(uint32_t type, spv::StorageClass storageClass);
+
+  /// \brief Adds a stage builtin variable whose value is of the given type.
+  ///
+  /// The corresponding pointer type of the given type will be constructed in
+  /// this method for the variable itself.
+  uint32_t addStageBuiltinVariable(uint32_t type, spv::BuiltIn);
 
 
   /// \brief Decorates the given target <result-id> with the given location.
   /// \brief Decorates the given target <result-id> with the given location.
   void decorateLocation(uint32_t targetId, uint32_t location);
   void decorateLocation(uint32_t targetId, uint32_t location);
@@ -110,14 +128,18 @@ public:
   // === Type ===
   // === Type ===
 
 
   uint32_t getVoidType();
   uint32_t getVoidType();
+  uint32_t getInt32Type();
   uint32_t getFloatType();
   uint32_t getFloatType();
-  uint32_t getVec2Type(uint32_t elemType);
-  uint32_t getVec3Type(uint32_t elemType);
-  uint32_t getVec4Type(uint32_t elemType);
+  uint32_t getVecType(uint32_t elemType, uint32_t elemCount);
   uint32_t getPointerType(uint32_t pointeeType, spv::StorageClass);
   uint32_t getPointerType(uint32_t pointeeType, spv::StorageClass);
+  uint32_t getStructType(llvm::ArrayRef<uint32_t> fieldTypes);
   uint32_t getFunctionType(uint32_t returnType,
   uint32_t getFunctionType(uint32_t returnType,
                            const std::vector<uint32_t> &paramTypes);
                            const std::vector<uint32_t> &paramTypes);
 
 
+  // === Constant ===
+
+  uint32_t getInt32Value(uint32_t value);
+
 private:
 private:
   /// \brief Map from basic blocks' <label-id> to their structured
   /// \brief Map from basic blocks' <label-id> to their structured
   /// representation.
   /// representation.

+ 23 - 5
tools/clang/include/clang/SPIRV/Structure.h

@@ -18,6 +18,7 @@
 #ifndef LLVM_CLANG_SPIRV_STRUCTURE_H
 #ifndef LLVM_CLANG_SPIRV_STRUCTURE_H
 #define LLVM_CLANG_SPIRV_STRUCTURE_H
 #define LLVM_CLANG_SPIRV_STRUCTURE_H
 
 
+#include <deque>
 #include <memory>
 #include <memory>
 #include <string>
 #include <string>
 #include <vector>
 #include <vector>
@@ -63,15 +64,18 @@ public:
   /// state.
   /// state.
   void take(InstBuilder *builder);
   void take(InstBuilder *builder);
 
 
-  /// \brief Add an instruction to this basic block.
-  inline void addInstruction(Instruction &&);
+  /// \brief Appends an instruction to this basic block.
+  inline void appendInstruction(Instruction &&);
+
+  /// \brief Preprends an instruction to this basic block.
+  inline void prependInstruction(Instruction &&);
 
 
   /// \brief Returns true if this basic block is terminated.
   /// \brief Returns true if this basic block is terminated.
   bool isTerminated() const;
   bool isTerminated() const;
 
 
 private:
 private:
   uint32_t labelId; ///< The label id for this basic block. Zero means invalid.
   uint32_t labelId; ///< The label id for this basic block. Zero means invalid.
-  std::vector<Instruction> instructions;
+  std::deque<Instruction> instructions;
 };
 };
 
 
 /// \brief The class representing a SPIR-V function.
 /// \brief The class representing a SPIR-V function.
@@ -99,12 +103,15 @@ public:
   void clear();
   void clear();
 
 
   /// \brief Serializes this function and feeds it to the comsumer in the given
   /// \brief Serializes this function and feeds it to the comsumer in the given
-  /// InstBuilder. After this call, this function will be in an invalid state.
+  /// InstBuilder. After this call, this function will be in an empty state.
   void take(InstBuilder *builder);
   void take(InstBuilder *builder);
 
 
   /// \brief Adds a parameter to this function.
   /// \brief Adds a parameter to this function.
   inline void addParameter(uint32_t paramResultType, uint32_t paramResultId);
   inline void addParameter(uint32_t paramResultType, uint32_t paramResultId);
 
 
+  /// \brief Adds a local variable to this function.
+  inline void addVariable(uint32_t varResultType, uint32_t varResultId);
+
   /// \brief Adds a basic block to this function.
   /// \brief Adds a basic block to this function.
   inline void addBasicBlock(std::unique_ptr<BasicBlock> block);
   inline void addBasicBlock(std::unique_ptr<BasicBlock> block);
 
 
@@ -113,8 +120,11 @@ private:
   uint32_t resultId;
   uint32_t resultId;
   spv::FunctionControlMask funcControl;
   spv::FunctionControlMask funcControl;
   uint32_t funcType;
   uint32_t funcType;
+
   /// Parameter <result-type> and <result-id> pairs.
   /// Parameter <result-type> and <result-id> pairs.
   std::vector<std::pair<uint32_t, uint32_t>> parameters;
   std::vector<std::pair<uint32_t, uint32_t>> parameters;
+  /// Local variable <result-type> and <result-id> pairs.
+  std::vector<std::pair<uint32_t, uint32_t>> variables;
   std::vector<std::unique_ptr<BasicBlock>> blocks;
   std::vector<std::unique_ptr<BasicBlock>> blocks;
 };
 };
 
 
@@ -268,10 +278,14 @@ void BasicBlock::clear() {
   instructions.clear();
   instructions.clear();
 }
 }
 
 
-void BasicBlock::addInstruction(Instruction &&inst) {
+void BasicBlock::appendInstruction(Instruction &&inst) {
   instructions.push_back(std::move(inst));
   instructions.push_back(std::move(inst));
 }
 }
 
 
+void BasicBlock::prependInstruction(Instruction &&inst) {
+  instructions.push_front(std::move(inst));
+}
+
 Function::Function()
 Function::Function()
     : resultType(0), resultId(0),
     : resultType(0), resultId(0),
       funcControl(spv::FunctionControlMask::MaskNone), funcType(0) {}
       funcControl(spv::FunctionControlMask::MaskNone), funcType(0) {}
@@ -290,6 +304,10 @@ void Function::addParameter(uint32_t rType, uint32_t rId) {
   parameters.emplace_back(rType, rId);
   parameters.emplace_back(rType, rId);
 }
 }
 
 
+void Function::addVariable(uint32_t varType, uint32_t varId) {
+  variables.emplace_back(varType, varId);
+}
+
 void Function::addBasicBlock(std::unique_ptr<BasicBlock> block) {
 void Function::addBasicBlock(std::unique_ptr<BasicBlock> block) {
   blocks.push_back(std::move(block));
   blocks.push_back(std::move(block));
 }
 }

+ 2 - 1
tools/clang/include/clang/SPIRV/Type.h

@@ -15,6 +15,7 @@
 
 
 #include "spirv/1.0/spirv.hpp11"
 #include "spirv/1.0/spirv.hpp11"
 #include "clang/SPIRV/Decoration.h"
 #include "clang/SPIRV/Decoration.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/Optional.h"
 
 
 namespace clang {
 namespace clang {
@@ -92,7 +93,7 @@ public:
                                      uint32_t component_type_id,
                                      uint32_t component_type_id,
                                      DecorationSet decs = {});
                                      DecorationSet decs = {});
   static const Type *getStruct(SPIRVContext &ctx,
   static const Type *getStruct(SPIRVContext &ctx,
-                               std::initializer_list<uint32_t> members,
+                               llvm::ArrayRef<uint32_t> members,
                                DecorationSet d = {});
                                DecorationSet d = {});
   static const Type *getOpaque(SPIRVContext &ctx, std::string name,
   static const Type *getOpaque(SPIRVContext &ctx, std::string name,
                                DecorationSet decs = {});
                                DecorationSet decs = {});

+ 394 - 225
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -8,12 +8,12 @@
 //===----------------------------------------------------------------------===//
 //===----------------------------------------------------------------------===//
 
 
 #include "clang/SPIRV/EmitSPIRVAction.h"
 #include "clang/SPIRV/EmitSPIRVAction.h"
+
 #include "clang/AST/AST.h"
 #include "clang/AST/AST.h"
 #include "clang/AST/ASTConsumer.h"
 #include "clang/AST/ASTConsumer.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/HlslTypes.h"
 #include "clang/AST/HlslTypes.h"
 #include "clang/AST/RecordLayout.h"
 #include "clang/AST/RecordLayout.h"
-#include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Basic/FileManager.h"
 #include "clang/Basic/FileManager.h"
 #include "clang/Basic/SourceManager.h"
 #include "clang/Basic/SourceManager.h"
@@ -25,41 +25,79 @@
 namespace clang {
 namespace clang {
 namespace {
 namespace {
 
 
-/// \brief Generates the corresponding SPIR-V type for the given Clang frontend
-/// type and returns the <result-id>.
+/// The class responsible to translate Clang frontend types into SPIR-V type
+/// instructions.
 ///
 ///
-/// The translation is recursive; all the types that the target type depends on
-/// will be generated.
-uint32_t translateType(QualType type, spirv::ModuleBuilder &theBuilder) {
-  // In AST, vector types are TypedefType of TemplateSpecializationType,
-  // which is nested deeply. So we do fast track check here.
-  const auto symbol = type.getAsString();
-  if (symbol == "float4") {
-    const uint32_t floatType = theBuilder.getFloatType();
-    return theBuilder.getVec4Type(floatType);
-  } else if (symbol == "float3") {
-    const uint32_t floatType = theBuilder.getFloatType();
-    return theBuilder.getVec3Type(floatType);
-  } else if (symbol == "float2") {
-    const uint32_t floatType = theBuilder.getFloatType();
-    return theBuilder.getVec2Type(floatType);
-  } else if (auto *builtinType = dyn_cast<BuiltinType>(type.getTypePtr())) {
-    switch (builtinType->getKind()) {
-    case BuiltinType::Void:
-      return theBuilder.getVoidType();
-    case BuiltinType::Float:
-      return theBuilder.getFloatType();
-    default:
-      // TODO: handle other primitive types
-      assert(false && "unhandled builtin type");
-      break;
+/// SPIR-V type instructions generated during translation will be emitted to
+/// the SPIR-V module builder passed into the constructor.
+/// Warnings and errors during the translation will be reported to the
+/// DiagnosticEngine passed into the constructor.
+class TypeTranslator {
+public:
+  TypeTranslator(spirv::ModuleBuilder &builder, DiagnosticsEngine &diag)
+      : theBuilder(builder), diags(diag) {}
+
+  /// \brief Generates the corresponding SPIR-V type for the given Clang
+  /// frontend type and returns the type's <result-id>. On failure, reports
+  /// the error and returns 0.
+  ///
+  /// The translation is recursive; all the types that the target type depends
+  /// on will be generated.
+  uint32_t translateType(QualType type) {
+    const auto *typePtr = type.getTypePtr();
+
+    // Primitive types
+    if (const auto *builtinType = dyn_cast<BuiltinType>(typePtr)) {
+      switch (builtinType->getKind()) {
+      case BuiltinType::Void:
+        return theBuilder.getVoidType();
+      case BuiltinType::Float:
+        return theBuilder.getFloatType();
+      default:
+        emitError("Primitive type '%0' is not supported yet.")
+            << builtinType->getTypeClassName();
+        return 0;
+      }
+    }
+
+    // In AST, vector types are TypedefType of TemplateSpecializationType.
+    // We handle them via HLSL type inspection functions.
+    if (hlsl::IsHLSLVecType(type)) {
+      const auto elemType = hlsl::GetHLSLVecElementType(type);
+      const auto elemCount = hlsl::GetHLSLVecSize(type);
+      return theBuilder.getVecType(translateType(elemType), elemCount);
+    }
+
+    // Struct type
+    if (const auto *structType = dyn_cast<RecordType>(typePtr)) {
+      const auto *decl = structType->getDecl();
+
+      // Collect all fields' types.
+      std::vector<uint32_t> fieldTypes;
+      for (const auto *field : decl->fields()) {
+        fieldTypes.push_back(translateType(field->getType()));
+      }
+
+      return theBuilder.getStructType(fieldTypes);
     }
     }
-  } else {
-    // TODO: handle other types
-    assert(false && "unhandled clang type");
+
+    emitError("Type '%0' is not supported yet.") << type->getTypeClassName();
+    return 0;
   }
   }
-  return 0;
-}
+
+private:
+  /// \brief Wrapper method to create an error message and report it
+  /// in the diagnostic engine associated with this consumer.
+  template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]) {
+    const auto diagId =
+        diags.getCustomDiagID(clang::DiagnosticsEngine::Error, message);
+    return diags.Report(diagId);
+  }
+
+private:
+  spirv::ModuleBuilder &theBuilder;
+  DiagnosticsEngine &diags;
+};
 
 
 /// \brief The class containing mappings from Clang frontend Decls to their
 /// \brief The class containing mappings from Clang frontend Decls to their
 /// corresponding SPIR-V <result-id>s.
 /// corresponding SPIR-V <result-id>s.
@@ -70,84 +108,40 @@ uint32_t translateType(QualType type, spirv::ModuleBuilder &theBuilder) {
 /// will be used to generate all SPIR-V instructions required.
 /// will be used to generate all SPIR-V instructions required.
 ///
 ///
 /// This class acts as a middle layer to handle the mapping between HLSL
 /// This class acts as a middle layer to handle the mapping between HLSL
-/// semantics and Vulkan interface (stage builtin/input/output) variables.
-/// Such mapping is required because of the semantic differences between DirectX
-/// and Vulkan and the essence of HLSL as the front-end language for DirectX.
+/// semantics and Vulkan stage (builtin/input/output) variables. Such mapping
+/// is required because of the semantic differences between DirectX and
+/// Vulkan and the essence of HLSL as the front-end language for DirectX.
 /// A normal variable attached with some semantic will be translated into a
 /// A normal variable attached with some semantic will be translated into a
-/// single interface variables if it is of non-struct type. If it is of struct
+/// single stage variables if it is of non-struct type. If it is of struct
 /// type, the fields with attached semantics will need to be translated into
 /// type, the fields with attached semantics will need to be translated into
-/// interface variables per Vulkan's requirements.
+/// stage variables per Vulkan's requirements.
 ///
 ///
-/// In the following class, we call a Decl or symbol as *remapped* when it is
-/// translated into an interface variable; otherwise, we call it as *normal*.
+/// In the following class, we call a Decl as *remapped* when it is translated
+/// into a stage variable; otherwise, we call it as *normal*. Remapped decls
+/// include:
+/// * FunctionDecl if the return value is attached with a semantic
+/// * ParmVarDecl if the parameter is attached with a semantic
+/// * FieldDecl if the field is attached with a semantic.
 class DeclResultIdMapper {
 class DeclResultIdMapper {
 public:
 public:
-  DeclResultIdMapper(spirv::ModuleBuilder *builder) : theBuilder(*builder) {}
+  DeclResultIdMapper(spv::ExecutionModel stage, spirv::ModuleBuilder &builder,
+                     DiagnosticsEngine &diag)
+      : shaderStage(stage), theBuilder(builder), typeTranslator(builder, diag) {
+  }
 
 
-  /// \brief Defines a function return value in this mapper and returns the
-  /// <result-id> for the final return type.
-  ///
-  /// The final return type is the "residual" type after "stripping" all
-  /// subtypes with attached semantics. For exmaple, stripping "float4 :
-  /// SV_Target" will result in "void", and stripping "struct { float4 :
-  /// SV_Target, float4 }" will result in "struct { float4 }".
-  ///
-  /// Proper SPIR-V instructions will be generated to create the corresponding
-  /// interface variable if stripping happens.
-  uint32_t defineFnReturn(FunctionDecl *funcDecl) {
+  /// \brief Creates the stage variables by parsing the semantics attached to
+  /// the given function's return value.
+  void createStageVarFromFnReturn(FunctionDecl *funcDecl) {
     // SemanticDecl for the return value is attached to the FunctionDecl.
     // SemanticDecl for the return value is attached to the FunctionDecl.
-    const auto sk = getInterfaceVarSemanticAndKind(funcDecl);
-    if (sk.second != InterfaceVariableKind::None) {
-      // Found return value with semantic attached. This means we need to map
-      // the return value to a single interface variable.
-      const uint32_t retTypeId =
-          translateType(funcDecl->getReturnType(), theBuilder);
-      // TODO: Change to the correct interface variable kind here.
-      const uint32_t varId = theBuilder.addStageIOVariable(
-          retTypeId, spv::StorageClass::Output, llvm::None);
-
-      stageOutputs.push_back(std::make_pair(varId, sk.first));
-      remappedDecls[funcDecl] = varId;
-      interfaceVars.insert(varId);
-
-      return theBuilder.getVoidType();
-    } else {
-      // TODO: We need to handle struct return types here.
-      return translateType(funcDecl->getReturnType(), theBuilder);
-    }
+    createStageVariables(funcDecl, false);
   }
   }
 
 
-  /// \brief Defines a function parameter in this mapper and returns the
-  /// <result-id> for the final parameter type. Returns 0 if the final type
-  /// is void.
-  ///
-  /// The final parameter type is the "residual" type after "stripping" all
-  /// subtypes will attached semantics. For exmaple, stripping "float4 :
-  /// SV_Target" will result in "void", and stripping "struct { float4 :
-  /// SV_Target, float4 }" will result in "struct { float4 }".
-  ///
-  /// Proper SPIR-V instructions will be generated to create the corresponding
-  /// interface variable if stripping happens.
-  uint32_t defineFnParam(ParmVarDecl *paramDecl) {
-    const auto sk = getInterfaceVarSemanticAndKind(paramDecl);
-    if (sk.second != InterfaceVariableKind::None) {
-      // Found parameter with semantic attached. This means we need to map the
-      // parameter to a single interface variable.
-      const uint32_t paramTypeId =
-          translateType(paramDecl->getType(), theBuilder);
-      // TODO: Change to the correct interface variable kind here.
-      const uint32_t varId = theBuilder.addStageIOVariable(
-          paramTypeId, spv::StorageClass::Input, llvm::None);
-
-      stageInputs.push_back(std::make_pair(varId, sk.first));
-      remappedDecls[paramDecl] = varId;
-      interfaceVars.insert(varId);
-
-      return 0;
-    } else {
-      // TODO: We need to handle struct parameter types here.
-      return translateType(paramDecl->getType(), theBuilder);
-    }
+  /// \brief Creates the stage variables by parsing the semantics attached to
+  /// the given function's parameter.
+  void createStageVarFromFnParam(ParmVarDecl *paramDecl) {
+    // TODO: We cannot treat all parameters as stage inputs because of
+    // out/input modifiers.
+    createStageVariables(paramDecl, true);
   }
   }
 
 
   /// \brief Registers a Decl's <result-id> without generating any SPIR-V
   /// \brief Registers a Decl's <result-id> without generating any SPIR-V
@@ -156,9 +150,9 @@ public:
     normalDecls[symbol] = resultId;
     normalDecls[symbol] = resultId;
   }
   }
 
 
-  /// \brief Returns true if the given <result-id> is for an interface variable.
-  bool isInterfaceVariable(uint32_t varId) const {
-    return interfaceVars.count(varId) != 0;
+  /// \brief Returns true if the given <result-id> is for a stage variable.
+  bool isStageVariable(uint32_t varId) const {
+    return stageVars.count(varId) != 0;
   }
   }
 
 
   /// \brief Returns the <result-id> for the given Decl.
   /// \brief Returns the <result-id> for the given Decl.
@@ -190,18 +184,22 @@ public:
     return 0;
     return 0;
   }
   }
 
 
-  /// \brief Returns all defined stage input and ouput variables in this mapper.
-  std::vector<uint32_t> collectStageIOVariables() {
-    std::vector<uint32_t> stageIOVars;
+  /// \brief Returns all defined stage (builtin/input/ouput) variables in this
+  /// mapper.
+  std::vector<uint32_t> collectStageVariables() const {
+    std::vector<uint32_t> stageVars;
 
 
+    for (const auto &builtin : stageBuiltins) {
+      stageVars.push_back(builtin.first);
+    }
     for (const auto &input : stageInputs) {
     for (const auto &input : stageInputs) {
-      stageIOVars.push_back(input.first);
+      stageVars.push_back(input.first);
     }
     }
     for (const auto &output : stageOutputs) {
     for (const auto &output : stageOutputs) {
-      stageIOVars.push_back(output.first);
+      stageVars.push_back(output.first);
     }
     }
 
 
-    return stageIOVars;
+    return stageVars;
   }
   }
 
 
   /// \brief Decorates all stage input and output variables with proper
   /// \brief Decorates all stage input and output variables with proper
@@ -224,87 +222,151 @@ public:
   }
   }
 
 
 private:
 private:
-  /// \brief Interface variable kind.
+  /// \brief Stage variable kind.
   ///
   ///
-  /// By interface variable, I mean all stage builtin, input, and output
-  /// variables. They participate in interface matching in Vulkan pipelines.
-  enum class InterfaceVariableKind {
-    None,   ///< Not an interface variable
-    Input,  ///< Stage input variable
-    Output, ///< Stage output variable
-    IO, ///< Interface variable that can act as both stage input or stage output
-    // TODO: builtins
+  /// Stage variables include builtin, input, and output variables.
+  /// They participate in interface matching in Vulkan pipelines.
+  enum class StageVarKind {
+    None,
+    Arbitary,
+    Position,
+    Color,
+    Target,
+    // TODO: other possible kinds
   };
   };
 
 
-  using InterfaceVarIdSemanticPair = std::pair<uint32_t, llvm::StringRef>;
+  using StageVarIdSemanticPair = std::pair<uint32_t, std::string>;
+
+  /// Returns the type of the given decl. If the given decl is a FunctionDecl,
+  /// returns its result type.
+  QualType getFnParamOrRetType(const DeclaratorDecl *decl) const {
+    if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
+      return funcDecl->getReturnType();
+    }
+    return decl->getType();
+  }
+
+  /// Creates all the stage variables mapped from semantics on the given decl.
+  ///
+  /// Assumes the decl has semantic attached to itself or to its fields.
+  void createStageVariables(const DeclaratorDecl *decl, bool actAsInput) {
+    QualType type = getFnParamOrRetType(decl);
+
+    if (type->isVoidType()) {
+      // No stage variables will be created for void type.
+      return;
+    }
+
+    const std::string semantic = getStageVarSemantic(decl);
+    if (!semantic.empty()) {
+      // Found semantic attached directly to this Decl. This means we need to
+      // map this decl to a single stage variable.
+      const uint32_t typeId = typeTranslator.translateType(type);
+      const auto kind = getStageVarKind(semantic);
+
+      if (actAsInput) {
+        // Stage (builtin) input variable cases
+        const uint32_t varId =
+            theBuilder.addStageIOVariable(typeId, spv::StorageClass::Input);
+
+        stageInputs.push_back(std::make_pair(varId, semantic));
+        remappedDecls[decl] = varId;
+        stageVars.insert(varId);
+      } else {
+        // Handle output builtin variables first
+        if (shaderStage == spv::ExecutionModel::Vertex &&
+            kind == StageVarKind::Position) {
+          const uint32_t varId = theBuilder.addStageBuiltinVariable(
+              typeId, spv::BuiltIn::Position);
+
+          stageBuiltins.push_back(std::make_pair(varId, semantic));
+          remappedDecls[decl] = varId;
+          stageVars.insert(varId);
+        } else {
+          // The rest are normal stage output variables
+          const uint32_t varId =
+              theBuilder.addStageIOVariable(typeId, spv::StorageClass::Output);
+
+          stageOutputs.push_back(std::make_pair(varId, semantic));
+          remappedDecls[decl] = varId;
+          stageVars.insert(varId);
+        }
+      }
+    } else {
+      // If the decl itself doesn't have semantic, it should be a struct having
+      // all its fields with semantics.
+      assert(type->isStructureType() &&
+             "found non-struct decls without semantics");
 
 
-  /// \brief Returns the interface variable's semantic and kind for the given
-  /// Decl.
-  std::pair<llvm::StringRef, InterfaceVariableKind>
-  getInterfaceVarSemanticAndKind(NamedDecl *decl) const {
+      const auto *structDecl = cast<RecordType>(type.getTypePtr())->getDecl();
+
+      // Recursively handle all the fields.
+      for (const auto *field : structDecl->fields()) {
+        createStageVariables(field, actAsInput);
+      }
+    }
+  }
+
+  /// \brief Returns the stage variable's kind for the given semantic.
+  StageVarKind getStageVarKind(llvm::StringRef semantic) const {
+    return llvm::StringSwitch<StageVarKind>(semantic)
+        .Case("", StageVarKind::None)
+        .StartsWith("COLOR", StageVarKind::Color)
+        .StartsWith("POSITION", StageVarKind::Position)
+        .StartsWith("SV_POSITION", StageVarKind::Position)
+        .StartsWith("SV_TARGET", StageVarKind::Target)
+        .Default(StageVarKind::Arbitary);
+  }
+
+  /// \brief Returns the stage variable's semantic for the given Decl.
+  std::string getStageVarSemantic(const NamedDecl *decl) const {
     for (auto *annotation : decl->getUnusualAnnotations()) {
     for (auto *annotation : decl->getUnusualAnnotations()) {
       if (auto *semantic = dyn_cast<hlsl::SemanticDecl>(annotation)) {
       if (auto *semantic = dyn_cast<hlsl::SemanticDecl>(annotation)) {
-        const llvm::StringRef name = semantic->SemanticName;
-        // TODO: We should check the semantic name ends with a number.
-        if (name.startswith("SV_TARGET")) {
-          return std::make_pair(name, InterfaceVariableKind::Output);
-        }
-        if (name.startswith("COLOR")) {
-          return std::make_pair(name, InterfaceVariableKind::IO);
-        }
+        return semantic->SemanticName.upper();
       }
       }
     }
     }
-    return std::make_pair("", InterfaceVariableKind::None);
+    return "";
   }
   }
 
 
 private:
 private:
+  const spv::ExecutionModel shaderStage;
   spirv::ModuleBuilder &theBuilder;
   spirv::ModuleBuilder &theBuilder;
+  TypeTranslator typeTranslator;
 
 
   /// Mapping of all remapped decls to their <result-id>s.
   /// Mapping of all remapped decls to their <result-id>s.
   llvm::DenseMap<const NamedDecl *, uint32_t> remappedDecls;
   llvm::DenseMap<const NamedDecl *, uint32_t> remappedDecls;
   /// Mapping of all normal decls to their <result-id>s.
   /// Mapping of all normal decls to their <result-id>s.
   llvm::DenseMap<const NamedDecl *, uint32_t> normalDecls;
   llvm::DenseMap<const NamedDecl *, uint32_t> normalDecls;
-  /// <result-id>s of all defined interface variables.
+  /// <result-id>s of all defined stage variables.
   ///
   ///
   /// We need to keep a separate list here to avoid looping through the
   /// We need to keep a separate list here to avoid looping through the
-  /// remappedDecls to find whether an <result-id> is for interface variable.
-  llvm::SmallSet<uint32_t, 16> interfaceVars;
+  /// remappedDecls to find whether an <result-id> is for a stage variable.
+  llvm::SmallSet<uint32_t, 16> stageVars;
 
 
   /// Stage input/oupt/builtin variables and their kinds.
   /// Stage input/oupt/builtin variables and their kinds.
   ///
   ///
   /// We need to keep a separate list here in order to sort them at the end
   /// We need to keep a separate list here in order to sort them at the end
   /// of the module building.
   /// of the module building.
-  llvm::SmallVector<InterfaceVarIdSemanticPair, 8> stageInputs;
-  llvm::SmallVector<InterfaceVarIdSemanticPair, 8> stageOutputs;
-  llvm::SmallVector<InterfaceVarIdSemanticPair, 8> stageBuiltins;
+  llvm::SmallVector<StageVarIdSemanticPair, 8> stageInputs;
+  llvm::SmallVector<StageVarIdSemanticPair, 8> stageOutputs;
+  llvm::SmallVector<StageVarIdSemanticPair, 8> stageBuiltins;
 };
 };
 
 
+/// SPIR-V emitter class. It consumes the HLSL AST and emits SPIR-V words.
+///
+/// This class only overrides the HandleTranslationUnit() method; Traversing
+/// through the AST is done manually instead of using ASTConsumer's harness.
 class SPIRVEmitter : public ASTConsumer {
 class SPIRVEmitter : public ASTConsumer {
 public:
 public:
   explicit SPIRVEmitter(CompilerInstance &ci)
   explicit SPIRVEmitter(CompilerInstance &ci)
-      : theCompilerInstance(ci), diags(ci.getDiagnostics()), theContext(),
-        theBuilder(&theContext), declIdMapper(&theBuilder),
+      : theCompilerInstance(ci), diags(ci.getDiagnostics()),
         entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
         entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
         shaderStage(getSpirvShaderStageFromHlslProfile(
         shaderStage(getSpirvShaderStageFromHlslProfile(
             ci.getCodeGenOpts().HLSLProfile.c_str())),
             ci.getCodeGenOpts().HLSLProfile.c_str())),
-        entryFunctionId(0), curFunction(nullptr) {}
-
-  /// \brief Wrapper method to create an error message and report it
-  /// in the diagnostic engine associated with this consumer
-  template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]) {
-    const auto diagId =
-        diags.getCustomDiagID(clang::DiagnosticsEngine::Error, message);
-    return diags.Report(diagId);
-  }
-
-  /// \brief Wrapper method to create a warning message and report it
-  /// in the diagnostic engine associated with this consumer
-  template <unsigned N>
-  DiagnosticBuilder emitWarning(const char (&message)[N]) {
-    const auto diagId =
-        diags.getCustomDiagID(clang::DiagnosticsEngine::Warning, message);
-    return diags.Report(diagId);
-  }
+        theContext(), theBuilder(&theContext),
+        declIdMapper(shaderStage, theBuilder, diags),
+        typeTranslator(theBuilder, diags), entryFunctionId(0),
+        curFunction(nullptr) {}
 
 
   spv::ExecutionModel getSpirvShaderStageFromHlslProfile(const char *profile) {
   spv::ExecutionModel getSpirvShaderStageFromHlslProfile(const char *profile) {
     assert(profile && "nullptr passed as HLSL profile.");
     assert(profile && "nullptr passed as HLSL profile.");
@@ -379,13 +441,27 @@ public:
 
 
     TranslationUnitDecl *tu = context.getTranslationUnitDecl();
     TranslationUnitDecl *tu = context.getTranslationUnitDecl();
 
 
-    // Process all top level Decls.
+    // A queue of functions we need to translate.
+    std::deque<FunctionDecl *> workQueue;
+
+    // The entry function is the seed of the queue.
     for (auto *decl : tu->decls()) {
     for (auto *decl : tu->decls()) {
-      doDecl(decl);
+      if (auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
+        if (funcDecl->getName() == entryFunctionName) {
+          workQueue.push_back(funcDecl);
+        }
+      }
+    }
+    // TODO: enlarge the queue upon seeing a function call.
+
+    // Translate all functions reachable from the entry function.
+    while (!workQueue.empty()) {
+      doFunctionDecl(workQueue.front());
+      workQueue.pop_front();
     }
     }
 
 
     theBuilder.addEntryPoint(shaderStage, entryFunctionId, entryFunctionName,
     theBuilder.addEntryPoint(shaderStage, entryFunctionId, entryFunctionName,
-                             declIdMapper.collectStageIOVariables());
+                             declIdMapper.collectStageVariables());
 
 
     AddExecutionModeForEntryPoint(shaderStage, entryFunctionId);
     AddExecutionModeForEntryPoint(shaderStage, entryFunctionId);
 
 
@@ -399,92 +475,150 @@ public:
   }
   }
 
 
   void doDecl(Decl *decl) {
   void doDecl(Decl *decl) {
-    if (auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
-      doFunctionDecl(funcDecl);
+    if (auto *varDecl = dyn_cast<VarDecl>(decl)) {
+      doVarDecl(varDecl);
     } else {
     } else {
-      emitWarning("Translation is not implemented for this decl type: %0")
-          << std::string(decl->getDeclKindName());
       // TODO: Implement handling of other Decl types.
       // TODO: Implement handling of other Decl types.
+      emitWarning("Decl type '%0' is not supported yet.")
+          << std::string(decl->getDeclKindName());
     }
     }
   }
   }
 
 
   void doFunctionDecl(FunctionDecl *decl) {
   void doFunctionDecl(FunctionDecl *decl) {
     curFunction = decl;
     curFunction = decl;
 
 
-    uint32_t retType = declIdMapper.defineFnReturn(decl);
-
-    // Process function parameters. We need to strip parameters mapping to stage
-    // builtin/input/output variables and use what's left in the function type.
-    std::vector<uint32_t> residualParamTypes;
-    std::vector<ParmVarDecl *> residualParams;
-
-    for (auto *param : decl->params()) {
-      // Get the "residual" parameter type. If nothing left after stripping,
-      // zero will be returned.
-      const uint32_t paramType = declIdMapper.defineFnParam(param);
-      if (paramType) {
-        residualParamTypes.push_back(paramType);
-        residualParams.push_back(param);
-      }
-    }
-
-    const uint32_t funcType =
-        theBuilder.getFunctionType(retType, residualParamTypes);
-    const std::string funcName = decl->getNameInfo().getAsString();
-    const uint32_t funcId =
-        theBuilder.beginFunction(funcType, retType, funcName);
-
-    // Register all the "residual" parameters into the mapper.
-    for (uint32_t i = 0; i < residualParams.size(); ++i) {
-      declIdMapper.registerDeclResultId(
-          residualParams[i], theBuilder.addFnParameter(residualParamTypes[i]));
-    }
-
-    if (decl->hasBody()) {
-      const uint32_t entryLabel = theBuilder.createBasicBlock();
-      theBuilder.setInsertPoint(entryLabel);
-
-      // Process all statments in the body.
-      for (Stmt *stmt : cast<CompoundStmt>(decl->getBody())->body())
-        doStmt(stmt);
-
-      // We have processed all Stmts in this function and now in the last basic
-      // block. Make sure we have OpReturn if this is a void(...) function.
-      if (retType == theBuilder.getVoidType() &&
-          !theBuilder.isCurrentBasicBlockTerminated()) {
-        theBuilder.createReturn();
+    const llvm::StringRef funcName = decl->getName();
+
+    if (funcName == entryFunctionName) {
+      // First create stage variables for the entry point.
+      declIdMapper.createStageVarFromFnReturn(decl);
+      for (auto *param : decl->params())
+        declIdMapper.createStageVarFromFnParam(param);
+
+      // Construct the function signature.
+      const uint32_t voidType = theBuilder.getVoidType();
+      const uint32_t funcType = theBuilder.getFunctionType(voidType, {});
+      const uint32_t funcId =
+          theBuilder.beginFunction(funcType, voidType, funcName);
+
+      if (decl->hasBody()) {
+        // The entry basic block.
+        const uint32_t entryLabel = theBuilder.createBasicBlock();
+        theBuilder.setInsertPoint(entryLabel);
+
+        // Process all statments in the body.
+        for (Stmt *stmt : cast<CompoundStmt>(decl->getBody())->body())
+          doStmt(stmt);
+
+        // We have processed all Stmts in this function and now in the last
+        // basic block. Make sure we have OpReturn if missing.
+        if (!theBuilder.isCurrentBasicBlockTerminated()) {
+          theBuilder.createReturn();
+        }
       }
       }
 
 
       theBuilder.endFunction();
       theBuilder.endFunction();
-    }
 
 
-    // Record the entry function's <result-id>.
-    if (entryFunctionName == funcName) {
+      // Record the entry function's <result-id>.
       entryFunctionId = funcId;
       entryFunctionId = funcId;
+    } else {
+      emitError("Non-entry functions are not supported yet.");
     }
     }
 
 
     curFunction = nullptr;
     curFunction = nullptr;
   }
   }
 
 
+  void doVarDecl(VarDecl *decl) {
+    if (decl->isLocalVarDecl()) {
+      const uint32_t ptrType = theBuilder.getPointerType(
+          typeTranslator.translateType(decl->getType()),
+          spv::StorageClass::Function);
+      const uint32_t varId = theBuilder.addFnVariable(ptrType);
+      declIdMapper.registerDeclResultId(decl, varId);
+    } else {
+      // TODO: handle global variables
+      emitError("Global variables are not supported yet.");
+    }
+  }
+
   void doStmt(Stmt *stmt) {
   void doStmt(Stmt *stmt) {
     if (auto *retStmt = dyn_cast<ReturnStmt>(stmt)) {
     if (auto *retStmt = dyn_cast<ReturnStmt>(stmt)) {
       doReturnStmt(retStmt);
       doReturnStmt(retStmt);
+    } else if (auto *declStmt = dyn_cast<DeclStmt>(stmt)) {
+      for (auto *decl : declStmt->decls()) {
+        doDecl(decl);
+      }
+    } else if (auto *binOp = dyn_cast<BinaryOperator>(stmt)) {
+      const auto opcode = binOp->getOpcode();
+      const uint32_t lhs = doExpr(binOp->getLHS());
+      const uint32_t rhs = doExpr(binOp->getRHS());
+
+      doBinaryOperator(opcode, lhs, rhs);
+    } else {
+      emitError("Stmt '%0' is not supported yet.") << stmt->getStmtClassName();
     }
     }
     // TODO: handle other statements
     // TODO: handle other statements
   }
   }
 
 
+  void doBinaryOperator(BinaryOperatorKind opcode, uint32_t lhs, uint32_t rhs) {
+    if (opcode == BO_Assign) {
+      theBuilder.createStore(lhs, rhs);
+    } else {
+      emitError("BinaryOperator '%0' is not supported yet.") << opcode;
+    }
+  }
+
   void doReturnStmt(ReturnStmt *stmt) {
   void doReturnStmt(ReturnStmt *stmt) {
+    // First get the <result-id> of the value we want to return.
     const uint32_t retValue = doExpr(stmt->getRetValue());
     const uint32_t retValue = doExpr(stmt->getRetValue());
-    const uint32_t interfaceVarId =
+
+    if (curFunction->getName() != entryFunctionName) {
+      theBuilder.createReturnValue(retValue);
+      return;
+    }
+
+    // SPIR-V requires the signature of entry functions to be void(), while
+    // in HLSL we can have non-void parameter and return types for entry points.
+    // So we should treat the ReturnStmt in entry functions specially.
+    //
+    // We need to walk through the return type, and for each subtype attached
+    // with semantics, write out the value to the corresponding stage variable
+    // mapped to the semantic.
+    const uint32_t stageVarId =
         declIdMapper.getRemappedDeclResultId(curFunction);
         declIdMapper.getRemappedDeclResultId(curFunction);
 
 
-    if (interfaceVarId) {
-      // The return value is mapped to an interface variable. We need to store
-      // the value into the interface variable instead.
-      theBuilder.createStore(interfaceVarId, retValue);
+    if (stageVarId) {
+      // The return value is mapped to a single stage variable. We just need
+      // to store the value into the stage variable instead.
+      theBuilder.createStore(stageVarId, retValue);
       theBuilder.createReturn();
       theBuilder.createReturn();
+      return;
+    }
+
+    QualType retType = stmt->getRetValue()->getType();
+
+    if (const auto *structType = retType->getAsStructureType()) {
+      // We are trying to return a value of struct type. Go through all fields.
+      uint32_t fieldIndex = 0;
+      for (const auto *field : structType->getDecl()->fields()) {
+        // Load the value from the current field.
+        const uint32_t valueType =
+            typeTranslator.translateType(field->getType());
+        // TODO: We may need to change the storage class accordingly.
+        const uint32_t ptrType = theBuilder.getPointerType(
+            typeTranslator.translateType(field->getType()),
+            spv::StorageClass::Function);
+        const uint32_t indexId = theBuilder.getInt32Value(fieldIndex++);
+        const uint32_t valuePtr =
+            theBuilder.createAccessChain(ptrType, retValue, {indexId});
+        const uint32_t value = theBuilder.createLoad(valueType, valuePtr);
+        // Store it to the corresponding stage variable.
+        const uint32_t targetVar = declIdMapper.getDeclResultId(field);
+        theBuilder.createStore(targetVar, value);
+      }
     } else {
     } else {
-      theBuilder.createReturnValue(retValue);
+      emitError("Return type '%0' for entry function is not supported yet.")
+          << retType->getTypeClassName();
     }
     }
   }
   }
 
 
@@ -497,32 +631,67 @@ public:
     } else if (auto *castExpr = dyn_cast<ImplicitCastExpr>(expr)) {
     } else if (auto *castExpr = dyn_cast<ImplicitCastExpr>(expr)) {
       const uint32_t fromValue = doExpr(castExpr->getSubExpr());
       const uint32_t fromValue = doExpr(castExpr->getSubExpr());
       // Using lvalue as rvalue will result in a ImplicitCast in Clang AST.
       // Using lvalue as rvalue will result in a ImplicitCast in Clang AST.
-      // This place gives us a place to inject the code for handling interface
-      // variables. Since using the <result-id> of an interface variable as
+      // This place gives us a place to inject the code for handling stage
+      // variables. Since using the <result-id> of a stage variable as
       // rvalue means OpLoad it first. For normal values, it is not required.
       // rvalue means OpLoad it first. For normal values, it is not required.
-      if (declIdMapper.isInterfaceVariable(fromValue)) {
+      if (declIdMapper.isStageVariable(fromValue)) {
         const uint32_t resultType =
         const uint32_t resultType =
-            translateType(castExpr->getType(), theBuilder);
+            typeTranslator.translateType(castExpr->getType());
         return theBuilder.createLoad(resultType, fromValue);
         return theBuilder.createLoad(resultType, fromValue);
       }
       }
       return fromValue;
       return fromValue;
+    } else if (auto *memberExpr = dyn_cast<MemberExpr>(expr)) {
+      const uint32_t base = doExpr(memberExpr->getBase());
+      auto *memberDecl = memberExpr->getMemberDecl();
+      if (auto *fieldDecl = dyn_cast<FieldDecl>(memberDecl)) {
+        const auto index = theBuilder.getInt32Value(fieldDecl->getFieldIndex());
+        const uint32_t fieldType =
+            typeTranslator.translateType(fieldDecl->getType());
+        const uint32_t ptrType =
+            theBuilder.getPointerType(fieldType, spv::StorageClass::Function);
+        return theBuilder.createAccessChain(ptrType, base, {index});
+      } else {
+        emitError("Decl '%0' in MemberExpr is not supported yet.")
+            << memberDecl->getDeclKindName();
+      }
     }
     }
+    emitError("Expr '%0' is not supported yet.") << expr->getStmtClassName();
     // TODO: handle other expressions
     // TODO: handle other expressions
     return 0;
     return 0;
   }
   }
 
 
+private:
+  /// \brief Wrapper method to create an error message and report it
+  /// in the diagnostic engine associated with this consumer.
+  template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]) {
+    const auto diagId =
+        diags.getCustomDiagID(clang::DiagnosticsEngine::Error, message);
+    return diags.Report(diagId);
+  }
+
+  /// \brief Wrapper method to create a warning message and report it
+  /// in the diagnostic engine associated with this consumer
+  template <unsigned N>
+  DiagnosticBuilder emitWarning(const char (&message)[N]) {
+    const auto diagId =
+        diags.getCustomDiagID(clang::DiagnosticsEngine::Warning, message);
+    return diags.Report(diagId);
+  }
+
 private:
 private:
   CompilerInstance &theCompilerInstance;
   CompilerInstance &theCompilerInstance;
   DiagnosticsEngine &diags;
   DiagnosticsEngine &diags;
-  spirv::SPIRVContext theContext;
-  spirv::ModuleBuilder theBuilder;
-  DeclResultIdMapper declIdMapper;
 
 
   /// Entry function name and shader stage. Both of them are derived from the
   /// Entry function name and shader stage. Both of them are derived from the
   /// command line and should be const.
   /// command line and should be const.
   const llvm::StringRef entryFunctionName;
   const llvm::StringRef entryFunctionName;
   const spv::ExecutionModel shaderStage;
   const spv::ExecutionModel shaderStage;
 
 
+  spirv::SPIRVContext theContext;
+  spirv::ModuleBuilder theBuilder;
+  DeclResultIdMapper declIdMapper;
+  TypeTranslator typeTranslator;
+
   /// <result-id> for the entry function. Initially it is zero and will be reset
   /// <result-id> for the entry function. Initially it is zero and will be reset
   /// when starting to translate the entry function.
   /// when starting to translate the entry function.
   uint32_t entryFunctionId;
   uint32_t entryFunctionId;

+ 27 - 0
tools/clang/lib/SPIRV/InstBuilderManual.cpp

@@ -15,6 +15,33 @@
 namespace clang {
 namespace clang {
 namespace spirv {
 namespace spirv {
 
 
+std::vector<uint32_t> InstBuilder::take() {
+  std::vector<uint32_t> result;
+
+  if (TheStatus == Status::Success && Expectation.empty() && !TheInst.empty()) {
+    TheInst.front() |= uint32_t(TheInst.size()) << 16;
+    result.swap(TheInst);
+  }
+
+  return result;
+}
+
+InstBuilder &InstBuilder::opConstant(uint32_t resultType, uint32_t resultId,
+                                     uint32_t value) {
+  if (!TheInst.empty()) {
+    TheStatus = Status::NestedInst;
+    return *this;
+  }
+
+  TheInst.reserve(4);
+  TheInst.emplace_back(static_cast<uint32_t>(spv::Op::OpConstant));
+  TheInst.emplace_back(resultType);
+  TheInst.emplace_back(resultId);
+  TheInst.emplace_back(value);
+
+  return *this;
+}
+
 void InstBuilder::encodeString(std::string value) {
 void InstBuilder::encodeString(std::string value) {
   const auto &words = string::encodeSPIRVString(value);
   const auto &words = string::encodeSPIRVString(value);
   TheInst.insert(TheInst.end(), words.begin(), words.end());
   TheInst.insert(TheInst.end(), words.begin(), words.end());

+ 86 - 17
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -47,11 +47,22 @@ uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
 uint32_t ModuleBuilder::addFnParameter(uint32_t type) {
 uint32_t ModuleBuilder::addFnParameter(uint32_t type) {
   assert(theFunction && "found detached parameter");
   assert(theFunction && "found detached parameter");
 
 
+  const uint32_t pointerType =
+      getPointerType(type, spv::StorageClass::Function);
   const uint32_t paramId = theContext.takeNextId();
   const uint32_t paramId = theContext.takeNextId();
-  theFunction->addParameter(type, paramId);
+  theFunction->addParameter(pointerType, paramId);
+
   return paramId;
   return paramId;
 }
 }
 
 
+uint32_t ModuleBuilder::addFnVariable(uint32_t type) {
+  assert(theFunction && "found detached local variable");
+
+  const uint32_t varId = theContext.takeNextId();
+  theFunction->addVariable(type, varId);
+  return varId;
+}
+
 bool ModuleBuilder::endFunction() {
 bool ModuleBuilder::endFunction() {
   if (theFunction == nullptr) {
   if (theFunction == nullptr) {
     assert(false && "no active function");
     assert(false && "no active function");
@@ -100,26 +111,35 @@ uint32_t ModuleBuilder::createLoad(uint32_t resultType, uint32_t pointer) {
   assert(insertPoint && "null insert point");
   assert(insertPoint && "null insert point");
   const uint32_t resultId = theContext.takeNextId();
   const uint32_t resultId = theContext.takeNextId();
   instBuilder.opLoad(resultType, resultId, pointer, llvm::None).x();
   instBuilder.opLoad(resultType, resultId, pointer, llvm::None).x();
-  insertPoint->addInstruction(std::move(constructSite));
+  insertPoint->appendInstruction(std::move(constructSite));
   return resultId;
   return resultId;
 }
 }
 
 
 void ModuleBuilder::createStore(uint32_t address, uint32_t value) {
 void ModuleBuilder::createStore(uint32_t address, uint32_t value) {
   assert(insertPoint && "null insert point");
   assert(insertPoint && "null insert point");
   instBuilder.opStore(address, value, llvm::None).x();
   instBuilder.opStore(address, value, llvm::None).x();
-  insertPoint->addInstruction(std::move(constructSite));
+  insertPoint->appendInstruction(std::move(constructSite));
+}
+
+uint32_t ModuleBuilder::createAccessChain(uint32_t resultType, uint32_t base,
+                                          llvm::ArrayRef<uint32_t> indexes) {
+  assert(insertPoint && "null insert point");
+  const uint32_t id = theContext.takeNextId();
+  instBuilder.opAccessChain(resultType, id, base, indexes).x();
+  insertPoint->appendInstruction(std::move(constructSite));
+  return id;
 }
 }
 
 
 void ModuleBuilder::createReturn() {
 void ModuleBuilder::createReturn() {
   assert(insertPoint && "null insert point");
   assert(insertPoint && "null insert point");
   instBuilder.opReturn().x();
   instBuilder.opReturn().x();
-  insertPoint->addInstruction(std::move(constructSite));
+  insertPoint->appendInstruction(std::move(constructSite));
 }
 }
 
 
 void ModuleBuilder::createReturnValue(uint32_t value) {
 void ModuleBuilder::createReturnValue(uint32_t value) {
   assert(insertPoint && "null insert point");
   assert(insertPoint && "null insert point");
   instBuilder.opReturnValue(value).x();
   instBuilder.opReturnValue(value).x();
-  insertPoint->addInstruction(std::move(constructSite));
+  insertPoint->appendInstruction(std::move(constructSite));
 }
 }
 
 
 uint32_t ModuleBuilder::getVoidType() {
 uint32_t ModuleBuilder::getVoidType() {
@@ -129,29 +149,46 @@ uint32_t ModuleBuilder::getVoidType() {
   return typeId;
   return typeId;
 }
 }
 
 
-uint32_t ModuleBuilder::getFloatType() {
-  const Type *type = Type::getFloat32(theContext);
+uint32_t ModuleBuilder::getInt32Type() {
+  const Type *type = Type::getInt32(theContext);
   const uint32_t typeId = theContext.getResultIdForType(type);
   const uint32_t typeId = theContext.getResultIdForType(type);
   theModule.addType(type, typeId);
   theModule.addType(type, typeId);
   return typeId;
   return typeId;
 }
 }
 
 
-uint32_t ModuleBuilder::getVec2Type(uint32_t elemType) {
-  const Type *type = Type::getVec2(theContext, elemType);
+uint32_t ModuleBuilder::getFloatType() {
+  const Type *type = Type::getFloat32(theContext);
   const uint32_t typeId = theContext.getResultIdForType(type);
   const uint32_t typeId = theContext.getResultIdForType(type);
   theModule.addType(type, typeId);
   theModule.addType(type, typeId);
   return typeId;
   return typeId;
 }
 }
 
 
-uint32_t ModuleBuilder::getVec3Type(uint32_t elemType) {
-  const Type *type = Type::getVec3(theContext, elemType);
+uint32_t ModuleBuilder::getVecType(uint32_t elemType, uint32_t elemCount) {
+  const Type *type = nullptr;
+  switch (elemCount) {
+  case 2:
+    type = Type::getVec2(theContext, elemType);
+    break;
+  case 3:
+    type = Type::getVec3(theContext, elemType);
+    break;
+  case 4:
+    type = Type::getVec4(theContext, elemType);
+    break;
+  default:
+    assert(false && "unhandled vector size");
+    // Error found. Return 0 as the <result-id> directly.
+    return 0;
+  }
+
   const uint32_t typeId = theContext.getResultIdForType(type);
   const uint32_t typeId = theContext.getResultIdForType(type);
   theModule.addType(type, typeId);
   theModule.addType(type, typeId);
+
   return typeId;
   return typeId;
 }
 }
 
 
-uint32_t ModuleBuilder::getVec4Type(uint32_t elemType) {
-  const Type *type = Type::getVec4(theContext, elemType);
+uint32_t ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes) {
+  const Type *type = Type::getStruct(theContext, fieldTypes);
   const uint32_t typeId = theContext.getResultIdForType(type);
   const uint32_t typeId = theContext.getResultIdForType(type);
   theModule.addType(type, typeId);
   theModule.addType(type, typeId);
   return typeId;
   return typeId;
@@ -174,13 +211,45 @@ uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
   return typeId;
   return typeId;
 }
 }
 
 
-uint32_t
-ModuleBuilder::addStageIOVariable(uint32_t type, spv::StorageClass storageClass,
-                                  llvm::Optional<uint32_t> initilizer) {
+uint32_t ModuleBuilder::getInt32Value(uint32_t value) {
+  const Type *i32Type = Type::getInt32(theContext);
+  const uint32_t i32TypeId = getInt32Type();
+  const uint32_t constantId = theContext.takeNextId();
+  instBuilder.opConstant(i32TypeId, constantId, value).x();
+  theModule.addConstant(*i32Type, std::move(constructSite));
+  return constantId;
+}
+
+uint32_t ModuleBuilder::addStageIOVariable(uint32_t type,
+                                           spv::StorageClass storageClass) {
   const uint32_t pointerType = getPointerType(type, storageClass);
   const uint32_t pointerType = getPointerType(type, storageClass);
   const uint32_t varId = theContext.takeNextId();
   const uint32_t varId = theContext.takeNextId();
-  instBuilder.opVariable(pointerType, varId, storageClass, initilizer).x();
+  instBuilder.opVariable(pointerType, varId, storageClass, llvm::None).x();
+  theModule.addVariable(std::move(constructSite));
+  return varId;
+}
+
+uint32_t ModuleBuilder::addStageBuiltinVariable(uint32_t type,
+                                                spv::BuiltIn builtin) {
+  spv::StorageClass sc = spv::StorageClass::Input;
+  switch (builtin) {
+  case spv::BuiltIn::Position:
+  case spv::BuiltIn::PointSize:
+    // TODO: add the rest output builtins
+    sc = spv::StorageClass::Output;
+    break;
+  default:
+    break;
+  }
+  const uint32_t pointerType = getPointerType(type, sc);
+  const uint32_t varId = theContext.takeNextId();
+  instBuilder.opVariable(pointerType, varId, sc, llvm::None).x();
   theModule.addVariable(std::move(constructSite));
   theModule.addVariable(std::move(constructSite));
+
+  // Decorate with the specified Builtin
+  const Decoration *d = Decoration::getBuiltIn(theContext, builtin);
+  theModule.addDecoration(*d, varId);
+
   return varId;
   return varId;
 }
 }
 
 

+ 16 - 0
tools/clang/lib/SPIRV/Structure.cpp

@@ -78,6 +78,7 @@ Function &Function::operator=(Function &&that) {
   funcControl = that.funcControl;
   funcControl = that.funcControl;
   funcType = that.funcType;
   funcType = that.funcType;
   parameters = std::move(that.parameters);
   parameters = std::move(that.parameters);
+  variables = std::move(that.variables);
   blocks = std::move(that.blocks);
   blocks = std::move(that.blocks);
 
 
   that.clear();
   that.clear();
@@ -91,17 +92,32 @@ void Function::clear() {
   funcControl = spv::FunctionControlMask::MaskNone;
   funcControl = spv::FunctionControlMask::MaskNone;
   funcType = 0;
   funcType = 0;
   parameters.clear();
   parameters.clear();
+  variables.clear();
   blocks.clear();
   blocks.clear();
 }
 }
 
 
 void Function::take(InstBuilder *builder) {
 void Function::take(InstBuilder *builder) {
   builder->opFunction(resultType, resultId, funcControl, funcType).x();
   builder->opFunction(resultType, resultId, funcControl, funcType).x();
+
+  // Write out all parameters.
   for (auto &param : parameters) {
   for (auto &param : parameters) {
     builder->opFunctionParameter(param.first, param.second).x();
     builder->opFunctionParameter(param.first, param.second).x();
   }
   }
+
+  // Preprend all local variables to the entry block.
+  for (auto &var : variables) {
+    blocks.front()->prependInstruction(
+        builder
+            ->opVariable(var.first, var.second, spv::StorageClass::Function,
+                         llvm::None)
+            .take());
+  }
+
+  // Write out all basic blocks.
   for (auto &block : blocks) {
   for (auto &block : blocks) {
     block->take(builder);
     block->take(builder);
   }
   }
+
   builder->opFunctionEnd().x();
   builder->opFunctionEnd().x();
   clear();
   clear();
 }
 }

+ 2 - 3
tools/clang/lib/SPIRV/Type.cpp

@@ -16,7 +16,7 @@ namespace spirv {
 
 
 Type::Type(spv::Op op, std::vector<uint32_t> arg,
 Type::Type(spv::Op op, std::vector<uint32_t> arg,
            std::set<const Decoration *> decs)
            std::set<const Decoration *> decs)
-    : opcode(op), args(arg), decorations(decs) {}
+    : opcode(op), args(std::move(arg)), decorations(std::move(decs)) {}
 
 
 const Type *Type::getUniqueType(SPIRVContext &context, const Type &t) {
 const Type *Type::getUniqueType(SPIRVContext &context, const Type &t) {
   return context.registerType(t);
   return context.registerType(t);
@@ -125,8 +125,7 @@ const Type *Type::getRuntimeArray(SPIRVContext &context,
   return getUniqueType(context, t);
   return getUniqueType(context, t);
 }
 }
 const Type *Type::getStruct(SPIRVContext &context,
 const Type *Type::getStruct(SPIRVContext &context,
-                            std::initializer_list<uint32_t> members,
-                            DecorationSet d) {
+                            llvm::ArrayRef<uint32_t> members, DecorationSet d) {
   Type t = Type(spv::Op::OpTypeStruct, std::vector<uint32_t>(members), d);
   Type t = Type(spv::Op::OpTypeStruct, std::vector<uint32_t>(members), d);
   return getUniqueType(context, t);
   return getUniqueType(context, t);
 }
 }

+ 5 - 5
tools/clang/test/CodeGenSPIRV/passthru-ps.hlsl2spv

@@ -12,22 +12,22 @@ float4 main(float4 input: COLOR): SV_TARGET
 // ; Schema: 0
 // ; Schema: 0
 // OpCapability Shader
 // OpCapability Shader
 // OpMemoryModel Logical GLSL450
 // OpMemoryModel Logical GLSL450
-// OpEntryPoint Fragment %main "main" %7 %4
+// OpEntryPoint Fragment %main "main" %6 %4
 // OpExecutionMode %main OriginUpperLeft
 // OpExecutionMode %main OriginUpperLeft
 // OpName %main "main"
 // OpName %main "main"
-// OpDecorate %7 Location 0
+// OpDecorate %6 Location 0
 // OpDecorate %4 Location 0
 // OpDecorate %4 Location 0
 // %float = OpTypeFloat 32
 // %float = OpTypeFloat 32
 // %v4float = OpTypeVector %float 4
 // %v4float = OpTypeVector %float 4
 // %_ptr_Output_v4float = OpTypePointer Output %v4float
 // %_ptr_Output_v4float = OpTypePointer Output %v4float
-// %void = OpTypeVoid
 // %_ptr_Input_v4float = OpTypePointer Input %v4float
 // %_ptr_Input_v4float = OpTypePointer Input %v4float
+// %void = OpTypeVoid
 // %8 = OpTypeFunction %void
 // %8 = OpTypeFunction %void
 // %4 = OpVariable %_ptr_Output_v4float Output
 // %4 = OpVariable %_ptr_Output_v4float Output
-// %7 = OpVariable %_ptr_Input_v4float Input
+// %6 = OpVariable %_ptr_Input_v4float Input
 // %main = OpFunction %void None %8
 // %main = OpFunction %void None %8
 // %10 = OpLabel
 // %10 = OpLabel
-// %11 = OpLoad %v4float %7
+// %11 = OpLoad %v4float %6
 // OpStore %4 %11
 // OpStore %4 %11
 // OpReturn
 // OpReturn
 // OpFunctionEnd
 // OpFunctionEnd

+ 37 - 24
tools/clang/test/CodeGenSPIRV/passthru-vs.hlsl2spv

@@ -1,9 +1,5 @@
 // Run: %dxc -T vs_6_0 -E VSmain
 // Run: %dxc -T vs_6_0 -E VSmain
 
 
-/*
-// This is the proper way to do it, but our code doesn't handle PSInput
-// as return type for now.
-
 struct PSInput {
 struct PSInput {
   float4 position : POSITION;
   float4 position : POSITION;
   float4 color : COLOR;
   float4 color : COLOR;
@@ -15,40 +11,57 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
   result.color = color;
   result.color = color;
   return result;
   return result;
 }
 }
-*/
-
-float4 VSmain(float4 position: POSITION, float4 color: COLOR) {
-  return color;
-}
 
 
 // TODO:
 // TODO:
-// Proper representation for the structure.
-// Input/Output interfaces for OpEntryPoint
-// Proper logic to determine ExecutionMode
-// Semantics
-// Function return value
+// Deduplicate integer constants
 
 
 
 
 // CHECK-WHOLE-SPIR-V:
 // CHECK-WHOLE-SPIR-V:
 // ; SPIR-V
 // ; SPIR-V
 // ; Version: 1.0
 // ; Version: 1.0
 // ; Generator: Google spiregg; 0
 // ; Generator: Google spiregg; 0
-// ; Bound: 11
+// ; Bound: 30
 // ; Schema: 0
 // ; Schema: 0
 // OpCapability Shader
 // OpCapability Shader
 // OpMemoryModel Logical GLSL450
 // OpMemoryModel Logical GLSL450
-// OpEntryPoint Vertex %VSmain "VSmain" %4
+// OpEntryPoint Vertex %VSmain "VSmain" %gl_Position %7 %8 %5
 // OpName %VSmain "VSmain"
 // OpName %VSmain "VSmain"
-// OpDecorate %4 Location 0
+// OpDecorate %gl_Position BuiltIn Position
+// OpDecorate %7 Location 0
+// OpDecorate %8 Location 1
+// OpDecorate %5 Location 0
 // %float = OpTypeFloat 32
 // %float = OpTypeFloat 32
 // %v4float = OpTypeVector %float 4
 // %v4float = OpTypeVector %float 4
+// %_ptr_Output_v4float = OpTypePointer Output %v4float
 // %_ptr_Input_v4float = OpTypePointer Input %v4float
 // %_ptr_Input_v4float = OpTypePointer Input %v4float
-// %5 = OpTypeFunction %v4float %v4float
 // %void = OpTypeVoid
 // %void = OpTypeVoid
-// %4 = OpVariable %_ptr_Input_v4float Input
-// %VSmain = OpFunction %v4float None %5
-// %7 = OpFunctionParameter %v4float
-// %8 = OpLabel
-// %9 = OpLoad %v4float %4
-// OpReturnValue %9
+// %10 = OpTypeFunction %void
+// %_struct_13 = OpTypeStruct %v4float %v4float
+// %_ptr_Function__struct_13 = OpTypePointer Function %_struct_13
+// %int = OpTypeInt 32 1
+// %_ptr_Function_v4float = OpTypePointer Function %v4float
+// %int_0 = OpConstant %int 0
+// %int_1 = OpConstant %int 1
+// %int_0_0 = OpConstant %int 0
+// %int_1_0 = OpConstant %int 1
+// %gl_Position = OpVariable %_ptr_Output_v4float Output
+// %5 = OpVariable %_ptr_Output_v4float Output
+// %7 = OpVariable %_ptr_Input_v4float Input
+// %8 = OpVariable %_ptr_Input_v4float Input
+// %VSmain = OpFunction %void None %10
+// %12 = OpLabel
+// %15 = OpVariable %_ptr_Function__struct_13 Function
+// %19 = OpAccessChain %_ptr_Function_v4float %15 %int_0
+// %20 = OpLoad %v4float %7
+// OpStore %19 %20
+// %22 = OpAccessChain %_ptr_Function_v4float %15 %int_1
+// %23 = OpLoad %v4float %8
+// OpStore %22 %23
+// %25 = OpAccessChain %_ptr_Function_v4float %15 %int_0_0
+// %26 = OpLoad %v4float %25
+// OpStore %gl_Position %26
+// %28 = OpAccessChain %_ptr_Function_v4float %15 %int_1_0
+// %29 = OpLoad %v4float %28
+// OpStore %5 %29
+// OpReturn
 // OpFunctionEnd
 // OpFunctionEnd

+ 3 - 3
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -15,17 +15,17 @@ TEST_F(WholeFileTest, EmptyVoidMain) {
   // TODO: change this test such that it does run validation.
   // TODO: change this test such that it does run validation.
   runWholeFileTest("empty-void-main.hlsl2spv",
   runWholeFileTest("empty-void-main.hlsl2spv",
                    /*generateHeader*/ true,
                    /*generateHeader*/ true,
-                   /*runValidation*/ false);
+                   /*runValidation*/ true);
 }
 }
 
 
 TEST_F(WholeFileTest, PassThruPixelShader) {
 TEST_F(WholeFileTest, PassThruPixelShader) {
   runWholeFileTest("passthru-ps.hlsl2spv",
   runWholeFileTest("passthru-ps.hlsl2spv",
                    /*generateHeader*/ true,
                    /*generateHeader*/ true,
-                   /*runValidation*/ false);
+                   /*runValidation*/ true);
 }
 }
 
 
 TEST_F(WholeFileTest, PassThruVertexShader) {
 TEST_F(WholeFileTest, PassThruVertexShader) {
   runWholeFileTest("passthru-vs.hlsl2spv",
   runWholeFileTest("passthru-vs.hlsl2spv",
                    /*generateHeader*/ true,
                    /*generateHeader*/ true,
-                   /*runValidation*/ false);
+                   /*runValidation*/ true);
 }
 }

+ 4 - 4
tools/clang/unittests/SPIRV/StructureTest.cpp

@@ -28,7 +28,7 @@ TEST(Structure, TakeBasicBlockHaveAllContents) {
   auto ib = constructInstBuilder(result);
   auto ib = constructInstBuilder(result);
 
 
   auto bb = BasicBlock(42);
   auto bb = BasicBlock(42);
-  bb.addInstruction(constructInst(spv::Op::OpReturn, {}));
+  bb.appendInstruction(constructInst(spv::Op::OpReturn, {}));
   bb.take(&ib);
   bb.take(&ib);
 
 
   std::vector<uint32_t> expected;
   std::vector<uint32_t> expected;
@@ -41,7 +41,7 @@ TEST(Structure, TakeBasicBlockHaveAllContents) {
 
 
 TEST(Structure, AfterClearBasicBlockIsEmpty) {
 TEST(Structure, AfterClearBasicBlockIsEmpty) {
   auto bb = BasicBlock(42);
   auto bb = BasicBlock(42);
-  bb.addInstruction(constructInst(spv::Op::OpNop, {}));
+  bb.appendInstruction(constructInst(spv::Op::OpNop, {}));
   EXPECT_FALSE(bb.isEmpty());
   EXPECT_FALSE(bb.isEmpty());
   bb.clear();
   bb.clear();
   EXPECT_TRUE(bb.isEmpty());
   EXPECT_TRUE(bb.isEmpty());
@@ -57,7 +57,7 @@ TEST(Structure, TakeFunctionHaveAllContents) {
   f.addParameter(1, 42);
   f.addParameter(1, 42);
 
 
   auto bb = llvm::make_unique<BasicBlock>(10);
   auto bb = llvm::make_unique<BasicBlock>(10);
-  bb->addInstruction(constructInst(spv::Op::OpReturn, {}));
+  bb->appendInstruction(constructInst(spv::Op::OpReturn, {}));
   f.addBasicBlock(std::move(bb));
   f.addBasicBlock(std::move(bb));
 
 
   std::vector<uint32_t> result;
   std::vector<uint32_t> result;
@@ -189,7 +189,7 @@ TEST(Structure, TakeModuleHaveAllContents) {
       voidId, funcId, spv::FunctionControlMask::MaskNone, funcTypeId);
       voidId, funcId, spv::FunctionControlMask::MaskNone, funcTypeId);
   const uint32_t bbId = context.takeNextId();
   const uint32_t bbId = context.takeNextId();
   auto bb = llvm::make_unique<BasicBlock>(bbId);
   auto bb = llvm::make_unique<BasicBlock>(bbId);
-  bb->addInstruction(constructInst(spv::Op::OpReturn, {}));
+  bb->appendInstruction(constructInst(spv::Op::OpReturn, {}));
   f->addBasicBlock(std::move(bb));
   f->addBasicBlock(std::move(bb));
   m.addFunction(std::move(f));
   m.addFunction(std::move(f));
   appendVector(&expected, constructInst(spv::Op::OpFunction,
   appendVector(&expected, constructInst(spv::Op::OpFunction,

+ 4 - 4
tools/clang/unittests/SPIRV/WholeFileCheck.cpp

@@ -216,11 +216,11 @@ void WholeFileTest::runWholeFileTest(std::string filename, bool generateHeader,
   // Disassemble the generated SPIR-V binary.
   // Disassemble the generated SPIR-V binary.
   ASSERT_TRUE(disassembleSpirvBinary(generateHeader));
   ASSERT_TRUE(disassembleSpirvBinary(generateHeader));
 
 
+  // Compare the expected and the generted SPIR-V code.
+  EXPECT_EQ(expectedSpirvAsm, generatedSpirvAsm);
+
   // Run SPIR-V validation if requested.
   // Run SPIR-V validation if requested.
   if (runSpirvValidation) {
   if (runSpirvValidation) {
-    ASSERT_TRUE(validateSpirvBinary());
+    EXPECT_TRUE(validateSpirvBinary());
   }
   }
-
-  // Compare the expected and the generted SPIR-V code.
-  EXPECT_EQ(expectedSpirvAsm, generatedSpirvAsm);
 }
 }