Browse Source

[spirv] Fully translate pass-through pixel shader! (#439)

* Add support for generating stage IO variables
* Add support for interface variable ids in entry point
* Add support for OpReturnValue
Lei Zhang 8 years ago
parent
commit
675151e6de

+ 62 - 17
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -24,16 +24,28 @@ namespace spirv {
 /// This class exports API for constructing SPIR-V binary interactively.
 /// At any time, there can only exist at most one function under building;
 /// but there can exist multiple basic blocks under construction.
+///
+/// Call `takeModule()` to get the SPIR-V words after finishing building the
+/// module.
 class ModuleBuilder {
 public:
   /// \brief Constructs a ModuleBuilder with the given SPIR-V context.
   explicit ModuleBuilder(SPIRVContext *);
 
+  /// \brief Takes the SPIR-V module under building. This will consume the
+  /// module under construction.
+  std::vector<uint32_t> takeModule();
+
+  // === Function and Basic Block ===
+
   /// \brief Begins building a SPIR-V function. At any time, there can only
   /// exist at most one function under building. Returns the <result-id> for the
   /// function on success. Returns zero on failure.
-  uint32_t beginFunction(uint32_t funcType, uint32_t returnType,
-                         const std::vector<uint32_t> &paramTypeIds = {});
+  uint32_t beginFunction(uint32_t funcType, uint32_t returnType);
+
+  /// \brief Registers a function parameter of the given type onto the current
+  /// function under construction and returns its <result-id>.
+  uint32_t addFnParameter(uint32_t type);
 
   /// \brief Ends building of the current function. Returns true of success,
   /// false on failure. All basic blocks constructed from the beginning or
@@ -42,57 +54,90 @@ public:
 
   /// \brief Creates a SPIR-V basic block. On success, returns the <label-id>
   /// for the basic block. On failure, returns zero.
-  uint32_t bbCreate();
+  uint32_t createBasicBlock();
 
-  /// \brief Ends building the SPIR-V basic block having the given <label-id>
-  /// with OpReturn. Returns true on success, false on failure.
-  bool bbReturn(uint32_t labelId);
+  /// \brief Returns true if the current basic block inserting into is
+  /// terminated.
+  inline bool isCurrentBasicBlockTerminated() const;
 
   /// \brief Sets insertion point to the basic block with the given <label-id>.
   /// Returns true on success, false on failure.
   bool setInsertPoint(uint32_t labelId);
 
+  // === Instruction at the current Insertion Point ===
+
+  /// \brief Creates a load instruction loading the value of the given
+  /// <result-type> from the given pointer. Returns the <result-id> for the
+  /// loaded value.
+  uint32_t createLoad(uint32_t resultType, uint32_t pointer);
+  /// \brief Creates a store instruction storing the given value into the given
+  /// address.
+  void createStore(uint32_t address, uint32_t value);
+
+  /// \brief Creates a return instruction.
+  void createReturn();
+  /// \brief Creates a return value instruction.
+  void createReturnValue(uint32_t value);
+
+  // === SPIR-V Module Structure ===
+
   inline void requireCapability(spv::Capability);
+
   inline void setAddressingModel(spv::AddressingModel);
   inline void setMemoryModel(spv::MemoryModel);
 
-  /// \brief Adds an Entry Point for the module under construction. We only
+  /// \brief Adds an entry point for the module under construction. We only
   /// support a single entry point per module for now.
   inline void addEntryPoint(spv::ExecutionModel em, uint32_t targetId,
                             std::string targetName,
-                            std::initializer_list<uint32_t> interfaces);
+                            llvm::ArrayRef<uint32_t> interfaces);
 
-  /// \brief Adds an Execution Mode to the module under construction.
+  /// \brief Adds an execution mode to the module under construction.
   inline void addExecutionMode(uint32_t entryPointId, spv::ExecutionMode em,
                                const std::vector<uint32_t> &params);
 
+  /// \brief Adds a stage input/ouput variable whose value is of the given type.
+  ///
+  /// The corresponding pointer type of the given type will be constructed in
+  /// this method for the variable itself.
+  uint32_t addStageIOVariable(uint32_t type, spv::StorageClass storageClass,
+                              llvm::Optional<uint32_t> initializer);
+
+  /// \brief Decorates the given target <result-id> with the given location.
+  void decorateLocation(uint32_t targetId, uint32_t location);
+
+  // === Type ===
+
   uint32_t getVoidType();
   uint32_t getFloatType();
   uint32_t getVec2Type(uint32_t elemType);
   uint32_t getVec3Type(uint32_t elemType);
   uint32_t getVec4Type(uint32_t elemType);
+  uint32_t getPointerType(uint32_t pointeeType, spv::StorageClass);
   uint32_t getFunctionType(uint32_t returnType,
                            const std::vector<uint32_t> &paramTypes);
 
-  /// \brief Takes the SPIR-V module under building. This will consume the
-  /// module under construction.
-  std::vector<uint32_t> takeModule();
-
 private:
   /// \brief Map from basic blocks' <label-id> to their structured
   /// representation.
   using OrderedBasicBlockMap =
       llvm::MapVector<uint32_t, std::unique_ptr<BasicBlock>>;
 
-  SPIRVContext &theContext;              ///< The SPIR-V context.
-  SPIRVModule theModule;                 ///< The module under building.
+  SPIRVContext &theContext; ///< The SPIR-V context.
+  SPIRVModule theModule;    ///< The module under building.
+
   std::unique_ptr<Function> theFunction; ///< The function under building.
   OrderedBasicBlockMap basicBlocks;      ///< The basic blocks under building.
   BasicBlock *insertPoint;               ///< The current insertion point.
-  std::vector<uint32_t> constructSite;   ///< InstBuilder construction site.
+
+  std::vector<uint32_t> constructSite; ///< InstBuilder construction site.
   InstBuilder instBuilder;
 };
 
+bool ModuleBuilder::isCurrentBasicBlockTerminated() const {
+  return insertPoint && insertPoint->isTerminated();
+}
+
 void ModuleBuilder::setAddressingModel(spv::AddressingModel am) {
   theModule.setAddressingModel(am);
 }
@@ -107,7 +152,7 @@ void ModuleBuilder::requireCapability(spv::Capability cap) {
 
 void ModuleBuilder::addEntryPoint(spv::ExecutionModel em, uint32_t targetId,
                                   std::string targetName,
-                                  std::initializer_list<uint32_t> interfaces) {
+                                  llvm::ArrayRef<uint32_t> interfaces) {
   theModule.addEntryPoint(em, targetId, targetName, interfaces);
 }
 

+ 21 - 18
tools/clang/include/clang/SPIRV/Structure.h

@@ -25,6 +25,7 @@
 #include "spirv/1.0/spirv.hpp11"
 #include "clang/SPIRV/InstBuilder.h"
 #include "clang/SPIRV/Type.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/Optional.h"
 
@@ -62,9 +63,12 @@ public:
   /// state.
   void take(InstBuilder *builder);
 
-  /// \brief add an instruction to this basic block.
+  /// \brief Add an instruction to this basic block.
   inline void addInstruction(Instruction &&);
 
+  /// \brief Returns true if this basic block is terminated.
+  bool isTerminated() const;
+
 private:
   uint32_t labelId; ///< The label id for this basic block. Zero means invalid.
   std::vector<Instruction> instructions;
@@ -140,12 +144,12 @@ struct ExtInstSet {
 /// \brief The struct representing an entry point.
 struct EntryPoint {
   inline EntryPoint(spv::ExecutionModel, uint32_t id, std::string name,
-                    std::initializer_list<uint32_t> interface);
+                    const std::vector<uint32_t> &interface);
 
   const spv::ExecutionModel executionModel;
   const uint32_t targetId;
   const std::string targetName;
-  const std::initializer_list<uint32_t> interfaces;
+  const std::vector<uint32_t> interfaces;
 };
 
 /// \brief The struct representing a debug name.
@@ -217,7 +221,7 @@ public:
   inline void setMemoryModel(spv::MemoryModel);
   inline void addEntryPoint(spv::ExecutionModel, uint32_t targetId,
                             std::string targetName,
-                            std::initializer_list<uint32_t> intefaces);
+                            llvm::ArrayRef<uint32_t> intefaces);
   inline void addExecutionMode(Instruction &&);
   // TODO: source code debug information
   inline void addDebugName(uint32_t targetId,
@@ -226,7 +230,7 @@ public:
   inline void addDecoration(const Decoration &decoration, uint32_t targetId);
   inline void addType(const Type *type, uint32_t resultId);
   inline void addConstant(const Type &type, Instruction &&constant);
-  // TODO: global variables
+  inline void addVariable(Instruction &&);
   inline void addFunction(std::unique_ptr<Function>);
 
 private:
@@ -250,7 +254,7 @@ private:
   // they should be handled together.
   llvm::MapVector<const Type *, uint32_t> types;
   std::vector<Constant> constants;
-  // TODO: global variables
+  std::vector<Instruction> variables;
   std::vector<std::unique_ptr<Function>> functions;
 };
 
@@ -294,22 +298,19 @@ void Function::addBasicBlock(std::unique_ptr<BasicBlock> block) {
 ExtInstSet::ExtInstSet(uint32_t id, std::string name)
     : resultId(id), setName(name) {}
 
-EntryPoint::EntryPoint(spv::ExecutionModel em, uint32_t id,
-                                    std::string name,
-                                    std::initializer_list<uint32_t> interface)
+EntryPoint::EntryPoint(spv::ExecutionModel em, uint32_t id, std::string name,
+                       const std::vector<uint32_t> &interface)
     : executionModel(em), targetId(id), targetName(std::move(name)),
-      interfaces(std::move(interface)) {}
+      interfaces(interface) {}
 
 DebugName::DebugName(uint32_t id, llvm::Optional<uint32_t> index,
-                                  std::string targetName)
+                     std::string targetName)
     : targetId(id), memberIndex(index), name(std::move(targetName)) {}
 
-DecorationIdPair::DecorationIdPair(const Decoration &decor,
-                                                uint32_t id)
+DecorationIdPair::DecorationIdPair(const Decoration &decor, uint32_t id)
     : decoration(decor), targetId(id) {}
 
-TypeIdPair::TypeIdPair(const Type &ty, uint32_t id)
-    : type(ty), resultId(id) {}
+TypeIdPair::TypeIdPair(const Type &ty, uint32_t id) : type(ty), resultId(id) {}
 
 Constant::Constant(const Type &ty, Instruction &&value)
     : type(ty), constant(std::move(value)) {}
@@ -336,9 +337,8 @@ void SPIRVModule::setMemoryModel(spv::MemoryModel mm) {
 }
 void SPIRVModule::addEntryPoint(spv::ExecutionModel em, uint32_t targetId,
                                 std::string name,
-                                std::initializer_list<uint32_t> interfaces) {
-  entryPoints.emplace_back(em, targetId, std::move(name),
-                           std::move(interfaces));
+                                llvm::ArrayRef<uint32_t> interfaces) {
+  entryPoints.emplace_back(em, targetId, std::move(name), interfaces);
 }
 void SPIRVModule::addExecutionMode(Instruction &&execMode) {
   executionModes.push_back(std::move(execMode));
@@ -358,6 +358,9 @@ void SPIRVModule::addType(const Type *type, uint32_t resultId) {
 void SPIRVModule::addConstant(const Type &type, Instruction &&constant) {
   constants.emplace_back(type, std::move(constant));
 };
+void SPIRVModule::addVariable(Instruction &&var) {
+  variables.push_back(std::move(var));
+}
 void SPIRVModule::addFunction(std::unique_ptr<Function> f) {
   functions.push_back(std::move(f));
 }

+ 398 - 73
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -11,6 +11,7 @@
 #include "clang/AST/AST.h"
 #include "clang/AST/ASTConsumer.h"
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/HlslTypes.h"
 #include "clang/AST/RecordLayout.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/Basic/Diagnostic.h"
@@ -22,22 +23,32 @@
 #include "llvm/Support/raw_ostream.h"
 
 namespace {
+
 spv::ExecutionModel getSpirvShaderStageFromHlslProfile(const char *profile) {
   // DXIL Models are:
-  // Profile (DXIL Model) : HLSL Shader Kind : SPIR-V Shader Kind
+  // Profile (DXIL Model) : HLSL Shader Kind : SPIR-V Shader Stage
   // vs_<version>         : Vertex Shader    : Vertex Shader
   // hs_<version>         : Hull Shader      : Tassellation Control Shader
   // ds_<version>         : Domain Shader    : Tessellation Evaluation Shader
   // gs_<version>         : Geometry Shader  : Geometry Shader
   // ps_<version>         : Pixel Shader     : Fragment Shader
   // cs_<version>         : Compute Shader   : Compute Shader
+
+  assert(profile && "nullptr passed in as profile");
+
   switch (profile[0]) {
-  case 'v': return spv::ExecutionModel::Vertex;
-  case 'h': return spv::ExecutionModel::TessellationControl;
-  case 'd': return spv::ExecutionModel::TessellationEvaluation;
-  case 'g': return spv::ExecutionModel::Geometry;
-  case 'p': return spv::ExecutionModel::Fragment;
-  case 'c': return spv::ExecutionModel::GLCompute;
+  case 'v':
+    return spv::ExecutionModel::Vertex;
+  case 'h':
+    return spv::ExecutionModel::TessellationControl;
+  case 'd':
+    return spv::ExecutionModel::TessellationEvaluation;
+  case 'g':
+    return spv::ExecutionModel::Geometry;
+  case 'p':
+    return spv::ExecutionModel::Fragment;
+  case 'c':
+    return spv::ExecutionModel::GLCompute;
   default:
     assert(false && "Unknown HLSL Profile");
     return spv::ExecutionModel::Fragment;
@@ -49,11 +60,269 @@ spv::ExecutionModel getSpirvShaderStageFromHlslProfile(const char *profile) {
 namespace clang {
 namespace {
 
+/// \brief Generates the corresponding SPIR-V type for the given Clang frontend
+/// type and returns the <result-id>.
+///
+/// The translation is recursive; all the types that the target type depends on
+/// will be generated.
+uint32_t translateType(QualType type, spirv::ModuleBuilder &theBuilder) {
+  // In AST, vector types are TypedefType of TemplateSpecializationType,
+  // which is nested deeply. So we do fast track check here.
+  const auto symbol = type.getAsString();
+  if (symbol == "float4") {
+    const uint32_t floatType = theBuilder.getFloatType();
+    return theBuilder.getVec4Type(floatType);
+  } else if (symbol == "float3") {
+    const uint32_t floatType = theBuilder.getFloatType();
+    return theBuilder.getVec3Type(floatType);
+  } else if (symbol == "float2") {
+    const uint32_t floatType = theBuilder.getFloatType();
+    return theBuilder.getVec2Type(floatType);
+  } else if (auto *builtinType = dyn_cast<BuiltinType>(type.getTypePtr())) {
+    switch (builtinType->getKind()) {
+    case BuiltinType::Void:
+      return theBuilder.getVoidType();
+    case BuiltinType::Float:
+      return theBuilder.getFloatType();
+    default:
+      // TODO: handle other primitive types
+      assert(false && "unhandled builtin type");
+      break;
+    }
+  } else {
+    // TODO: handle other types
+    assert(false && "unhandled clang type");
+  }
+  return 0;
+}
+
+/// \brief The class containing mappings from Clang frontend Decls to their
+/// corresponding SPIR-V <result-id>s.
+///
+/// All symbols defined in the AST should be "defined" or registered in this
+/// class and have their <result-id>s queried from this class. In the process
+/// of defining a Decl, the SPIR-V module builder passed into the constructor
+/// will be used to generate all SPIR-V instructions required.
+///
+/// This class acts as a middle layer to handle the mapping between HLSL
+/// semantics and Vulkan interface (stage builtin/input/output) variables.
+/// Such mapping is required because of the semantic differences between DirectX
+/// and Vulkan and the essence of HLSL as the front-end language for DirectX.
+/// A normal variable attached with some semantic will be translated into a
+/// single interface variables if it is of non-struct type. If it is of struct
+/// type, the fields with attached semantics will need to be translated into
+/// interface variables per Vulkan's requirements.
+///
+/// In the following class, we call a Decl or symbol as *remapped* when it is
+/// translated into an interface variable; otherwise, we call it as *normal*.
+class DeclResultIdMapper {
+public:
+  DeclResultIdMapper(spirv::ModuleBuilder *builder) : theBuilder(*builder) {}
+
+  /// \brief Defines a function return value in this mapper and returns the
+  /// <result-id> for the final return type.
+  ///
+  /// The final return type is the "residual" type after "stripping" all
+  /// subtypes with attached semantics. For exmaple, stripping "float4 :
+  /// SV_Target" will result in "void", and stripping "struct { float4 :
+  /// SV_Target, float4 }" will result in "struct { float4 }".
+  ///
+  /// Proper SPIR-V instructions will be generated to create the corresponding
+  /// interface variable if stripping happens.
+  uint32_t defineFnReturn(FunctionDecl *funcDecl) {
+    // SemanticDecl for the return value is attached to the FunctionDecl.
+    const auto sk = getInterfaceVarSemanticAndKind(funcDecl);
+    if (sk.second != InterfaceVariableKind::None) {
+      // Found return value with semantic attached. This means we need to map
+      // the return value to a single interface variable.
+      const uint32_t retTypeId =
+          translateType(funcDecl->getReturnType(), theBuilder);
+      // TODO: Change to the correct interface variable kind here.
+      const uint32_t varId = theBuilder.addStageIOVariable(
+          retTypeId, spv::StorageClass::Output, llvm::None);
+
+      stageOutputs.push_back(std::make_pair(varId, sk.first));
+      remappedDecls[funcDecl] = varId;
+      interfaceVars.insert(varId);
+
+      return theBuilder.getVoidType();
+    } else {
+      // TODO: We need to handle struct return types here.
+      return translateType(funcDecl->getReturnType(), theBuilder);
+    }
+  }
+
+  /// \brief Defines a function parameter in this mapper and returns the
+  /// <result-id> for the final parameter type. Returns 0 if the final type
+  /// is void.
+  ///
+  /// The final parameter type is the "residual" type after "stripping" all
+  /// subtypes will attached semantics. For exmaple, stripping "float4 :
+  /// SV_Target" will result in "void", and stripping "struct { float4 :
+  /// SV_Target, float4 }" will result in "struct { float4 }".
+  ///
+  /// Proper SPIR-V instructions will be generated to create the corresponding
+  /// interface variable if stripping happens.
+  uint32_t defineFnParam(ParmVarDecl *paramDecl) {
+    const auto sk = getInterfaceVarSemanticAndKind(paramDecl);
+    if (sk.second != InterfaceVariableKind::None) {
+      // Found parameter with semantic attached. This means we need to map the
+      // parameter to a single interface variable.
+      const uint32_t paramTypeId =
+          translateType(paramDecl->getType(), theBuilder);
+      // TODO: Change to the correct interface variable kind here.
+      const uint32_t varId = theBuilder.addStageIOVariable(
+          paramTypeId, spv::StorageClass::Input, llvm::None);
+
+      stageInputs.push_back(std::make_pair(varId, sk.first));
+      remappedDecls[paramDecl] = varId;
+      interfaceVars.insert(varId);
+
+      return 0;
+    } else {
+      // TODO: We need to handle struct parameter types here.
+      return translateType(paramDecl->getType(), theBuilder);
+    }
+  }
+
+  /// \brief Registers a Decl's <result-id> without generating any SPIR-V
+  /// instruction.
+  void registerDeclResultId(const NamedDecl *symbol, uint32_t resultId) {
+    normalDecls[symbol] = resultId;
+  }
+
+  /// \brief Returns true if the given <result-id> is for an interface variable.
+  bool isInterfaceVariable(uint32_t varId) const {
+    return interfaceVars.count(varId) != 0;
+  }
+
+  /// \brief Returns the <result-id> for the given Decl.
+  uint32_t getDeclResultId(const NamedDecl *decl) const {
+    if (const uint32_t id = getRemappedDeclResultId(decl))
+      return id;
+    if (const uint32_t id = getNormalDeclResultId(decl))
+      return id;
+
+    assert(false && "found unregistered Decl in DeclResultIdMapper");
+    return 0;
+  }
+
+  /// \brief Returns the <result-id> for the given remapped Decl. Returns zero
+  /// if it is not a registered remapped Decl.
+  uint32_t getRemappedDeclResultId(const NamedDecl *decl) const {
+    auto it = remappedDecls.find(decl);
+    if (it != remappedDecls.end())
+      return it->second;
+    return 0;
+  }
+
+  /// \brief Returns the <result-id> for the given normal Decl. Returns zero if
+  /// it is not a registered normal Decl.
+  uint32_t getNormalDeclResultId(const NamedDecl *decl) const {
+    auto it = normalDecls.find(decl);
+    if (it != normalDecls.end())
+      return it->second;
+    return 0;
+  }
+
+  /// \brief Returns all defined stage input and ouput variables in this mapper.
+  std::vector<uint32_t> collectStageIOVariables() {
+    std::vector<uint32_t> stageIOVars;
+
+    for (const auto &input : stageInputs) {
+      stageIOVars.push_back(input.first);
+    }
+    for (const auto &output : stageOutputs) {
+      stageIOVars.push_back(output.first);
+    }
+
+    return stageIOVars;
+  }
+
+  /// \brief Decorates all stage input and output variables with proper
+  /// location.
+  ///
+  /// This method will writes the location assignment into the module under
+  /// construction.
+  void finalizeStageIOLocations() {
+    uint32_t nextInputLocation = 0;
+    uint32_t nextOutputLocation = 0;
+
+    // TODO: sort the variables according to some criteria first, e.g.,
+    // alphabetical order of semantic names.
+    for (const auto &input : stageInputs) {
+      theBuilder.decorateLocation(input.first, nextInputLocation++);
+    }
+    for (const auto &output : stageOutputs) {
+      theBuilder.decorateLocation(output.first, nextOutputLocation++);
+    }
+  }
+
+private:
+  /// \brief Interface variable kind.
+  ///
+  /// By interface variable, I mean all stage builtin, input, and output
+  /// variables. They participate in interface matching in Vulkan pipelines.
+  enum class InterfaceVariableKind {
+    None,   ///< Not an interface variable
+    Input,  ///< Stage input variable
+    Output, ///< Stage output variable
+    IO, ///< Interface variable that can act as both stage input or stage output
+    // TODO: builtins
+  };
+
+  using InterfaceVarIdSemanticPair = std::pair<uint32_t, llvm::StringRef>;
+
+  /// \brief Returns the interface variable's semantic and kind for the given
+  /// Decl.
+  std::pair<llvm::StringRef, InterfaceVariableKind>
+  getInterfaceVarSemanticAndKind(NamedDecl *decl) const {
+    for (auto *annotation : decl->getUnusualAnnotations()) {
+      if (auto *semantic = dyn_cast<hlsl::SemanticDecl>(annotation)) {
+        const llvm::StringRef name = semantic->SemanticName;
+        // TODO: We should check the semantic name ends with a number.
+        if (name.startswith("SV_TARGET")) {
+          return std::make_pair(name, InterfaceVariableKind::Output);
+        }
+        if (name.startswith("COLOR")) {
+          return std::make_pair(name, InterfaceVariableKind::IO);
+        }
+      }
+    }
+    return std::make_pair("", InterfaceVariableKind::None);
+  }
+
+private:
+  spirv::ModuleBuilder &theBuilder;
+
+  /// Mapping of all remapped decls to their <result-id>s.
+  llvm::DenseMap<const NamedDecl *, uint32_t> remappedDecls;
+  /// Mapping of all normal decls to their <result-id>s.
+  llvm::DenseMap<const NamedDecl *, uint32_t> normalDecls;
+  /// <result-id>s of all defined interface variables.
+  ///
+  /// We need to keep a separate list here to avoid looping through the
+  /// remappedDecls to find whether an <result-id> is for interface variable.
+  llvm::SmallSet<uint32_t, 16> interfaceVars;
+
+  /// Stage input/oupt/builtin variables and their kinds.
+  ///
+  /// We need to keep a separate list here in order to sort them at the end
+  /// of the module building.
+  llvm::SmallVector<InterfaceVarIdSemanticPair, 8> stageInputs;
+  llvm::SmallVector<InterfaceVarIdSemanticPair, 8> stageOutputs;
+  llvm::SmallVector<InterfaceVarIdSemanticPair, 8> stageBuiltins;
+};
+
 class SPIRVEmitter : public ASTConsumer {
 public:
   explicit SPIRVEmitter(CompilerInstance &ci)
-      : theCompilerInstance(ci), outStream(*ci.getOutStream()), theContext(),
-        theBuilder(&theContext) {}
+      : theCompilerInstance(ci), theContext(), theBuilder(&theContext),
+        declIdMapper(&theBuilder),
+        entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
+        shaderStage(getSpirvShaderStageFromHlslProfile(
+            ci.getCodeGenOpts().HLSLProfile.c_str())),
+        entryFunctionId(0), curFunction(nullptr) {}
 
   void AddRequiredCapabilitiesForExecutionModel(spv::ExecutionModel em) {
     if (em == spv::ExecutionModel::TessellationControl ||
@@ -77,8 +346,7 @@ public:
       // fragment shaders. Currently using OriginUpperLeft as default.
       theBuilder.addExecutionMode(entryPointId,
                                   spv::ExecutionMode::OriginUpperLeft, {});
-    }
-    else {
+    } else {
       // TODO: Implement logic for adding proper execution mode for other shader
       // stages. Silently skipping for now.
     }
@@ -93,14 +361,25 @@ public:
     theBuilder.setAddressingModel(spv::AddressingModel::Logical);
     theBuilder.setMemoryModel(spv::MemoryModel::GLSL450);
 
+    TranslationUnitDecl *tu = context.getTranslationUnitDecl();
+
     // Process all top level Decls.
-    for (auto *decl : context.getTranslationUnitDecl()->decls()) {
+    for (auto *decl : tu->decls()) {
       doDecl(decl);
     }
 
+    theBuilder.addEntryPoint(shaderStage, entryFunctionId, entryFunctionName,
+                             declIdMapper.collectStageIOVariables());
+
+    AddExecutionModeForEntryPoint(shaderStage, entryFunctionId);
+
+    // Add Location decorations to stage input/output variables.
+    declIdMapper.finalizeStageIOLocations();
+
     // Output the constructed module.
     std::vector<uint32_t> m = theBuilder.takeModule();
-    outStream.write(reinterpret_cast<const char *>(m.data()), m.size() * 4);
+    theCompilerInstance.getOutStream()->write(
+        reinterpret_cast<const char *>(m.data()), m.size() * 4);
   }
 
   void doDecl(Decl *decl) {
@@ -112,76 +391,122 @@ public:
   }
 
   void doFunctionDecl(FunctionDecl *decl) {
-    std::vector<uint32_t> funcParamTypeIds;
-    const uint32_t funcType = translateFunctionType(decl, &funcParamTypeIds);
-    const uint32_t retType = translateType(decl->getReturnType());
-
-    const uint32_t funcId =
-        theBuilder.beginFunction(funcType, retType, funcParamTypeIds);
-    // TODO: handle function body
-    const uint32_t entryLabel = theBuilder.bbCreate();
-    theBuilder.bbReturn(entryLabel);
-    theBuilder.endFunction();
-
-    // Add an entry point to the module if necessary
-    const std::string hlslEntryFn =
-        theCompilerInstance.getCodeGenOpts().HLSLEntryFunction;
-    if (hlslEntryFn == decl->getNameInfo().getAsString()) {
-      const spv::ExecutionModel em = getSpirvShaderStageFromHlslProfile(
-          theCompilerInstance.getCodeGenOpts().HLSLProfile.c_str());
-      // TODO: Pass correct input/output interfaces to addEntryPoint.
-      theBuilder.addEntryPoint(em, funcId, hlslEntryFn, {});
-
-      // OpExecutionMode declares an execution mode for an entry point.
-      AddExecutionModeForEntryPoint(em, funcId);
-    }
-  }
-
-  uint32_t translateFunctionType(FunctionDecl *decl,
-                                 std::vector<uint32_t> *paramTypes) {
-    const uint32_t retType = translateType(decl->getReturnType());
+    curFunction = decl;
+
+    uint32_t retType = declIdMapper.defineFnReturn(decl);
+
+    // Process function parameters. We need to strip parameters mapping to stage
+    // builtin/input/output variables and use what's left in the function type.
+    std::vector<uint32_t> residualParamTypes;
+    std::vector<ParmVarDecl *> residualParams;
+
     for (auto *param : decl->params()) {
-      paramTypes->push_back(translateType(param->getType()));
-    }
-    return theBuilder.getFunctionType(retType, *paramTypes);
-  }
-
-  uint32_t translateType(QualType type) {
-    // In AST, vector types are TypedefType of TemplateSpecializationType,
-    // which is nested deeply. So we do fast track check here.
-    const auto symbol = type.getAsString();
-    if (symbol == "float4") {
-      const uint32_t floatType = theBuilder.getFloatType();
-      return theBuilder.getVec4Type(floatType);
-    } else if (symbol == "float3") {
-      const uint32_t floatType = theBuilder.getFloatType();
-      return theBuilder.getVec3Type(floatType);
-    } else if (symbol == "float2") {
-      const uint32_t floatType = theBuilder.getFloatType();
-      return theBuilder.getVec2Type(floatType);
-    } else if (auto *builtinType = dyn_cast<BuiltinType>(type.getTypePtr())) {
-      switch (builtinType->getKind()) {
-      case BuiltinType::Void:
-        return theBuilder.getVoidType();
-      case BuiltinType::Float:
-        return theBuilder.getFloatType();
-      default:
-        // TODO: handle other primitive types
-        assert(false && "unhandled builtin type");
-        break;
+      // Get the "residual" parameter type. If nothing left after stripping,
+      // zero will be returned.
+      const uint32_t paramType = declIdMapper.defineFnParam(param);
+      if (paramType) {
+        residualParamTypes.push_back(paramType);
+        residualParams.push_back(param);
       }
+    }
+
+    const uint32_t funcType =
+        theBuilder.getFunctionType(retType, residualParamTypes);
+    const uint32_t funcId = theBuilder.beginFunction(funcType, retType);
+
+    // Register all the "residual" parameters into the mapper.
+    for (uint32_t i = 0; i < residualParams.size(); ++i) {
+      declIdMapper.registerDeclResultId(
+          residualParams[i], theBuilder.addFnParameter(residualParamTypes[i]));
+    }
+
+    if (decl->hasBody()) {
+      const uint32_t entryLabel = theBuilder.createBasicBlock();
+      theBuilder.setInsertPoint(entryLabel);
+
+      // Process all statments in the body.
+      for (Stmt *stmt : cast<CompoundStmt>(decl->getBody())->body())
+        doStmt(stmt);
+
+      // We have processed all Stmts in this function and now in the last basic
+      // block. Make sure we have OpReturn if this is a void(...) function.
+      if (retType == theBuilder.getVoidType() &&
+          !theBuilder.isCurrentBasicBlockTerminated()) {
+        theBuilder.createReturn();
+      }
+
+      theBuilder.endFunction();
+    }
+
+    // Record the entry function's <result-id>.
+    if (entryFunctionName == decl->getNameInfo().getAsString()) {
+      entryFunctionId = funcId;
+    }
+
+    curFunction = nullptr;
+  }
+
+  void doStmt(Stmt *stmt) {
+    if (auto *retStmt = dyn_cast<ReturnStmt>(stmt)) {
+      doReturnStmt(retStmt);
+    }
+    // TODO: handle other statements
+  }
+
+  void doReturnStmt(ReturnStmt *stmt) {
+    const uint32_t retValue = doExpr(stmt->getRetValue());
+    const uint32_t interfaceVarId =
+        declIdMapper.getRemappedDeclResultId(curFunction);
+
+    if (interfaceVarId) {
+      // The return value is mapped to an interface variable. We need to store
+      // the value into the interface variable instead.
+      theBuilder.createStore(interfaceVarId, retValue);
+      theBuilder.createReturn();
     } else {
-      // TODO: handle other types
-      assert(false && "unhandled clang type");
+      theBuilder.createReturnValue(retValue);
     }
+  }
+
+  uint32_t doExpr(Expr *expr) {
+    if (auto *delRefExpr = dyn_cast<DeclRefExpr>(expr)) {
+      // Returns the <result-id> of the referenced Decl.
+      const NamedDecl *referredDecl = delRefExpr->getFoundDecl();
+      assert(referredDecl && "found non-NamedDecl referenced");
+      return declIdMapper.getDeclResultId(referredDecl);
+    } else if (auto *castExpr = dyn_cast<ImplicitCastExpr>(expr)) {
+      const uint32_t fromValue = doExpr(castExpr->getSubExpr());
+      // Using lvalue as rvalue will result in a ImplicitCast in Clang AST.
+      // This place gives us a place to inject the code for handling interface
+      // variables. Since using the <result-id> of an interface variable as
+      // rvalue means OpLoad it first. For normal values, it is not required.
+      if (declIdMapper.isInterfaceVariable(fromValue)) {
+        const uint32_t resultType =
+            translateType(castExpr->getType(), theBuilder);
+        return theBuilder.createLoad(resultType, fromValue);
+      }
+      return fromValue;
+    }
+    // TODO: handle other expressions
     return 0;
   }
 
 private:
-  raw_ostream &outStream;
+  CompilerInstance &theCompilerInstance;
   spirv::SPIRVContext theContext;
   spirv::ModuleBuilder theBuilder;
-  CompilerInstance &theCompilerInstance;
+  DeclResultIdMapper declIdMapper;
+
+  /// Entry function name and shader stage. Both of them are derived from the
+  /// command line and should be const.
+  const llvm::StringRef entryFunctionName;
+  const spv::ExecutionModel shaderStage;
+
+  /// <result-id> for the entry function. Initially it is zero and will be reset
+  /// when starting to translate the entry function.
+  uint32_t entryFunctionId;
+  /// The current function under traversal.
+  const FunctionDecl *curFunction;
 };
 
 } // namespace

+ 59 - 21
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -24,9 +24,7 @@ ModuleBuilder::ModuleBuilder(SPIRVContext *C)
   });
 }
 
-uint32_t
-ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
-                             const std::vector<uint32_t> &paramTypeIds) {
+uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType) {
   if (theFunction) {
     assert(false && "found nested function");
     return 0;
@@ -37,15 +35,17 @@ ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
   theFunction = llvm::make_unique<Function>(
       returnType, fId, spv::FunctionControlMask::MaskNone, funcType);
 
-  // Any OpFunction must be immediately followed by one OpFunctionParameter
-  // instruction per each formal parameter of the function.
-  for (const auto &typeId : paramTypeIds) {
-    theFunction->addParameter(typeId, theContext.takeNextId());
-  }
-
   return fId;
 }
 
+uint32_t ModuleBuilder::addFnParameter(uint32_t type) {
+  assert(theFunction && "found detached parameter");
+
+  const uint32_t paramId = theContext.takeNextId();
+  theFunction->addParameter(type, paramId);
+  return paramId;
+}
+
 bool ModuleBuilder::endFunction() {
   if (theFunction == nullptr) {
     assert(false && "no active function");
@@ -68,7 +68,7 @@ bool ModuleBuilder::endFunction() {
   return true;
 }
 
-uint32_t ModuleBuilder::bbCreate() {
+uint32_t ModuleBuilder::createBasicBlock() {
   if (theFunction == nullptr) {
     assert(false && "found detached basic block");
     return 0;
@@ -80,26 +80,40 @@ uint32_t ModuleBuilder::bbCreate() {
   return labelId;
 }
 
-bool ModuleBuilder::bbReturn(uint32_t labelId) {
+bool ModuleBuilder::setInsertPoint(uint32_t labelId) {
   auto it = basicBlocks.find(labelId);
   if (it == basicBlocks.end()) {
     assert(false && "invalid <label-id>");
     return false;
   }
+  insertPoint = it->second.get();
+  return true;
+}
+
+uint32_t ModuleBuilder::createLoad(uint32_t resultType, uint32_t pointer) {
+  assert(insertPoint && "null insert point");
+  const uint32_t resultId = theContext.takeNextId();
+  instBuilder.opLoad(resultType, resultId, pointer, llvm::None).x();
+  insertPoint->addInstruction(std::move(constructSite));
+  return resultId;
+}
+
+void ModuleBuilder::createStore(uint32_t address, uint32_t value) {
+  assert(insertPoint && "null insert point");
+  instBuilder.opStore(address, value, llvm::None).x();
+  insertPoint->addInstruction(std::move(constructSite));
+}
 
+void ModuleBuilder::createReturn() {
+  assert(insertPoint && "null insert point");
   instBuilder.opReturn().x();
-  it->second->addInstruction(std::move(constructSite));
-  return true;
+  insertPoint->addInstruction(std::move(constructSite));
 }
 
-bool ModuleBuilder::setInsertPoint(uint32_t labelId) {
-  auto it = basicBlocks.find(labelId);
-  if (it == basicBlocks.end()) {
-    assert(false && "invalid <label-id>");
-    return false;
-  }
-  insertPoint = it->second.get();
-  return true;
+void ModuleBuilder::createReturnValue(uint32_t value) {
+  assert(insertPoint && "null insert point");
+  instBuilder.opReturnValue(value).x();
+  insertPoint->addInstruction(std::move(constructSite));
 }
 
 uint32_t ModuleBuilder::getVoidType() {
@@ -146,6 +160,30 @@ ModuleBuilder::getFunctionType(uint32_t returnType,
   return typeId;
 }
 
+uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
+                                       spv::StorageClass storageClass) {
+  const Type *type = Type::getPointer(theContext, storageClass, pointeeType);
+  const uint32_t typeId = theContext.getResultIdForType(type);
+  theModule.addType(type, typeId);
+  return typeId;
+}
+
+uint32_t
+ModuleBuilder::addStageIOVariable(uint32_t type, spv::StorageClass storageClass,
+                                  llvm::Optional<uint32_t> initilizer) {
+  const uint32_t pointerType = getPointerType(type, storageClass);
+  const uint32_t varId = theContext.takeNextId();
+  instBuilder.opVariable(pointerType, varId, storageClass, initilizer).x();
+  theModule.addVariable(std::move(constructSite));
+  return varId;
+}
+
+void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
+  const Decoration *d =
+      Decoration::getLocation(theContext, location, llvm::None);
+  theModule.addDecoration(*d, targetId);
+}
+
 std::vector<uint32_t> ModuleBuilder::takeModule() {
   theModule.setBound(theContext.getNextId());
 

+ 12 - 4
tools/clang/lib/SPIRV/Structure.cpp

@@ -50,8 +50,7 @@ void BasicBlock::take(InstBuilder *builder) {
   // Make sure we have a terminator instruction at the end.
   // TODO: This is a little bit ugly. It suggests that we should put the opcode
   // in the Instruction struct. But fine for now.
-  assert(!instructions.empty() && isTerminator(static_cast<spv::Op>(
-                                      instructions.back().front() & 0xffff)));
+  assert(isTerminated() && "found basic block without terminator");
   builder->opLabel(labelId).x();
   for (auto &inst : instructions) {
     builder->getConsumer()(std::move(inst));
@@ -59,6 +58,13 @@ void BasicBlock::take(InstBuilder *builder) {
   clear();
 }
 
+bool BasicBlock::isTerminated() const {
+  return !instructions.empty() &&
+         isTerminator(
+             // Take the last 16 bits and convert it into opcode
+             static_cast<spv::Op>(instructions.back().front() & 0xffff));
+}
+
 Function::Function(Function &&that)
     : resultType(that.resultType), resultId(that.resultId),
       funcControl(that.funcControl), funcType(that.funcType),
@@ -162,7 +168,7 @@ void SPIRVModule::take(InstBuilder *builder) {
   for (auto &inst : entryPoints) {
     builder
         ->opEntryPoint(inst.executionModel, inst.targetId,
-                       std::move(inst.targetName), std::move(inst.interfaces))
+                       std::move(inst.targetName), inst.interfaces)
         .x();
   }
 
@@ -194,7 +200,9 @@ void SPIRVModule::take(InstBuilder *builder) {
     consumer(std::move(c.constant));
   }
 
-  // TODO: global variables
+  for (auto &v : variables) {
+    consumer(std::move(v));
+  }
 
   for (uint32_t i = 0; i < functions.size(); ++i) {
     functions[i]->take(builder);

+ 15 - 14
tools/clang/test/CodeGenSPIRV/passthru-ps.hlsl2spv

@@ -4,28 +4,29 @@ float4 main(float4 input: COLOR): SV_TARGET
     return input;
 }
 
-
-// TODO:
-// Input/Output interfaces are missing from OpEntryPoint
-// Semantics
-// Function return value
-
-
 // CHECK-WHOLE-SPIR-V:
 // ; SPIR-V
 // ; Version: 1.0
 // ; Generator: Google spiregg; 0
-// ; Bound: 7
+// ; Bound: 12
 // ; Schema: 0
 // OpCapability Shader
 // OpMemoryModel Logical GLSL450
-// OpEntryPoint Fragment %4 "main"
-// OpExecutionMode %4 OriginUpperLeft
+// OpEntryPoint Fragment %9 "main" %7 %4
+// OpExecutionMode %9 OriginUpperLeft
+// OpDecorate %7 Location 0
+// OpDecorate %4 Location 0
 // %float = OpTypeFloat 32
 // %v4float = OpTypeVector %float 4
-// %3 = OpTypeFunction %v4float %v4float
-// %4 = OpFunction %v4float None %3
-// %5 = OpFunctionParameter %v4float
-// %6 = OpLabel
+// %_ptr_Output_v4float = OpTypePointer Output %v4float
+// %void = OpTypeVoid
+// %_ptr_Input_v4float = OpTypePointer Input %v4float
+// %8 = OpTypeFunction %void
+// %4 = OpVariable %_ptr_Output_v4float Output
+// %7 = OpVariable %_ptr_Input_v4float Input
+// %9 = OpFunction %void None %8
+// %10 = OpLabel
+// %11 = OpLoad %v4float %7
+// OpStore %4 %11
 // OpReturn
 // OpFunctionEnd

+ 12 - 8
tools/clang/test/CodeGenSPIRV/passthru-vs.hlsl2spv

@@ -33,17 +33,21 @@ float4 VSmain(float4 position: POSITION, float4 color: COLOR) {
 // ; SPIR-V
 // ; Version: 1.0
 // ; Generator: Google spiregg; 0
-// ; Bound: 8
+// ; Bound: 11
 // ; Schema: 0
 // OpCapability Shader
 // OpMemoryModel Logical GLSL450
-// OpEntryPoint Vertex %4 "VSmain"
+// OpEntryPoint Vertex %6 "VSmain" %4
+// OpDecorate %4 Location 0
 // %float = OpTypeFloat 32
 // %v4float = OpTypeVector %float 4
-// %3 = OpTypeFunction %v4float %v4float %v4float
-// %4 = OpFunction %v4float None %3
-// %5 = OpFunctionParameter %v4float
-// %6 = OpFunctionParameter %v4float
-// %7 = OpLabel
-// OpReturn
+// %_ptr_Input_v4float = OpTypePointer Input %v4float
+// %5 = OpTypeFunction %v4float %v4float
+// %void = OpTypeVoid
+// %4 = OpVariable %_ptr_Input_v4float Input
+// %6 = OpFunction %v4float None %5
+// %7 = OpFunctionParameter %v4float
+// %8 = OpLabel
+// %9 = OpLoad %v4float %4
+// OpReturnValue %9
 // OpFunctionEnd

+ 3 - 2
tools/clang/unittests/SPIRV/ModuleBuilderTest.cpp

@@ -55,9 +55,10 @@ TEST(ModuleBuilder, CreateBasicBlock) {
   const auto fId = context.getNextId();
   EXPECT_NE(0, builder.beginFunction(fType, rType));
   const auto labelId = context.getNextId();
-  const auto resultId = builder.bbCreate();
+  const auto resultId = builder.createBasicBlock();
   EXPECT_EQ(labelId, resultId);
-  EXPECT_TRUE(builder.bbReturn(resultId));
+  builder.setInsertPoint(resultId);
+  builder.createReturn();
   EXPECT_TRUE(builder.endFunction());
 
   const auto result = builder.takeModule();