Explorar el Código

[spirv] CodeGen for EntryPoint and FnParam (#434)

* [spirv] CodeGen for OpEntryPoint

OpEntryPoint currently doesn't handle input/output interfaces.
OpExecutionMode declares an execution mode for an entry point.
The code currently uses OriginUpperLeft as default ExecutionMode.

TODO:
* Implement input/output interfaces for OpEntryPoint.
* Implement the logic to determine the execution mode properly based
on shader stage, primitive type, semantics, etc.

* [spirv] CodeGen for OpFunctionParameter.
Ehsan hace 8 años
padre
commit
d8f82f1e4d

+ 32 - 3
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -32,7 +32,9 @@ public:
   /// \brief Begins building a SPIR-V function. At any time, there can only
   /// exist at most one function under building. Returns the <result-id> for the
   /// function on success. Returns zero on failure.
-  uint32_t beginFunction(uint32_t funcType, uint32_t returnType);
+  uint32_t beginFunction(uint32_t funcType, uint32_t returnType,
+                         const std::vector<uint32_t> &paramTypeIds = {});
+
   /// \brief Ends building of the current function. Returns true of success,
   /// false on failure. All basic blocks constructed from the beginning or
   /// after ending the previous function will be collected into this function.
@@ -41,6 +43,7 @@ public:
   /// \brief Creates a SPIR-V basic block. On success, returns the <label-id>
   /// for the basic block. On failure, returns zero.
   uint32_t bbCreate();
+
   /// \brief Ends building the SPIR-V basic block having the given <label-id>
   /// with OpReturn. Returns true on success, false on failure.
   bool bbReturn(uint32_t labelId);
@@ -50,10 +53,19 @@ public:
   bool setInsertPoint(uint32_t labelId);
 
   inline void requireCapability(spv::Capability);
-
   inline void setAddressingModel(spv::AddressingModel);
   inline void setMemoryModel(spv::MemoryModel);
 
+  /// \brief Adds an Entry Point for the module under construction. We only
+  /// support a single entry point per module for now.
+  inline void addEntryPoint(spv::ExecutionModel em, uint32_t targetId,
+                            std::string targetName,
+                            std::initializer_list<uint32_t> interfaces);
+
+  /// \brief Adds an Execution Mode to the module under construction.
+  inline void addExecutionMode(uint32_t entryPointId, spv::ExecutionMode em,
+                               const std::vector<uint32_t> &params);
+
   uint32_t getVoidType();
   uint32_t getFloatType();
   uint32_t getVec2Type(uint32_t elemType);
@@ -93,7 +105,24 @@ void ModuleBuilder::requireCapability(spv::Capability cap) {
   theModule.addCapability(cap);
 }
 
+void ModuleBuilder::addEntryPoint(spv::ExecutionModel em, uint32_t targetId,
+                                  std::string targetName,
+                                  std::initializer_list<uint32_t> interfaces) {
+  theModule.addEntryPoint(em, targetId, targetName, interfaces);
+}
+
+void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
+                                     spv::ExecutionMode em,
+                                     const std::vector<uint32_t> &params) {
+  instBuilder.opExecutionMode(entryPointId, em);
+  for (const auto &param : params) {
+    instBuilder.literalInteger(param);
+  }
+  instBuilder.x();
+  theModule.addExecutionMode(std::move(constructSite));
+}
+
 } // end namespace spirv
 } // end namespace clang
 
-#endif
+#endif

+ 94 - 91
tools/clang/include/clang/SPIRV/Structure.h

@@ -89,6 +89,7 @@ public:
 
   /// \brief Returns true if this function is empty.
   inline bool isEmpty() const;
+
   /// \brief Clears all paramters and basic blocks and turns this function into
   /// an empty function.
   void clear();
@@ -99,6 +100,7 @@ public:
 
   /// \brief Adds a parameter to this function.
   inline void addParameter(uint32_t paramResultType, uint32_t paramResultId);
+
   /// \brief Adds a basic block to this function.
   inline void addBasicBlock(std::unique_ptr<BasicBlock> block);
 
@@ -112,6 +114,74 @@ private:
   std::vector<std::unique_ptr<BasicBlock>> blocks;
 };
 
+/// \brief The struct representing a SPIR-V module header.
+struct Header {
+  /// \brief Default constructs a SPIR-V module header with id bound 0.
+  Header();
+
+  /// \brief Feeds the consumer with all the SPIR-V words for this header.
+  void collect(const WordConsumer &consumer);
+
+  const uint32_t magicNumber;
+  const uint32_t version;
+  const uint32_t generator;
+  uint32_t bound;
+  const uint32_t reserved;
+};
+
+/// \brief The struct representing an extended instruction set.
+struct ExtInstSet {
+  inline ExtInstSet(uint32_t id, std::string name);
+
+  const uint32_t resultId;
+  const std::string setName;
+};
+
+/// \brief The struct representing an entry point.
+struct EntryPoint {
+  inline EntryPoint(spv::ExecutionModel, uint32_t id, std::string name,
+                    std::initializer_list<uint32_t> interface);
+
+  const spv::ExecutionModel executionModel;
+  const uint32_t targetId;
+  const std::string targetName;
+  const std::initializer_list<uint32_t> interfaces;
+};
+
+/// \brief The struct representing a debug name.
+struct DebugName {
+  inline DebugName(uint32_t id, llvm::Optional<uint32_t> index,
+                   std::string targetName);
+
+  const uint32_t targetId;
+  const llvm::Optional<uint32_t> memberIndex;
+  const std::string name;
+};
+
+/// \brief The struct representing a deocoration and its target <result-id>.
+struct DecorationIdPair {
+  inline DecorationIdPair(const Decoration &decor, uint32_t id);
+
+  const Decoration &decoration;
+  const uint32_t targetId;
+};
+
+/// \brief The struct representing a type and its <result-id>.
+struct TypeIdPair {
+  inline TypeIdPair(const Type &ty, uint32_t id);
+
+  const Type &type;
+  const uint32_t resultId;
+};
+
+/// \brief The struct representing a constant and its type.
+struct Constant {
+  inline Constant(const Type &ty, Instruction &&value);
+
+  const Type &type;
+  Instruction constant;
+};
+
 /// \brief The class representing a SPIR-V module.
 class SPIRVModule {
 public:
@@ -160,73 +230,6 @@ public:
   inline void addFunction(std::unique_ptr<Function>);
 
 private:
-  /// \brief The struct representing a SPIR-V module header.
-  struct Header {
-    /// \brief Default constructs a SPIR-V module header with id bound 0.
-    Header();
-
-    /// \brief Feeds the consumer with all the SPIR-V words for this header.
-    void collect(const WordConsumer &consumer);
-
-    const uint32_t magicNumber;
-    const uint32_t version;
-    const uint32_t generator;
-    uint32_t bound;
-    const uint32_t reserved;
-  };
-
-  /// \brief The struct representing an extended instruction set.
-  struct ExtInstSet {
-    inline ExtInstSet(uint32_t id, std::string name);
-
-    const uint32_t resultId;
-    const std::string setName;
-  };
-
-  /// \brief The struct representing an entry point.
-  struct EntryPoint {
-    inline EntryPoint(spv::ExecutionModel, uint32_t id, std::string name,
-                      std::initializer_list<uint32_t> interface);
-
-    const spv::ExecutionModel executionModel;
-    const uint32_t targetId;
-    const std::string targetName;
-    const std::initializer_list<uint32_t> interfaces;
-  };
-
-  /// \brief The struct representing a debug name.
-  struct DebugName {
-    inline DebugName(uint32_t id, llvm::Optional<uint32_t> index,
-                     std::string targetName);
-
-    const uint32_t targetId;
-    const llvm::Optional<uint32_t> memberIndex;
-    const std::string name;
-  };
-
-  /// \brief The struct representing a deocoration and its target <result-id>.
-  struct DecorationIdPair {
-    inline DecorationIdPair(const Decoration &decor, uint32_t id);
-
-    const Decoration &decoration;
-    const uint32_t targetId;
-  };
-
-  /// \brief The struct representing a type and its <result-id>.
-  struct TypeIdPair {
-    inline TypeIdPair(const Type &ty, uint32_t id);
-
-    const Type &type;
-    const uint32_t resultId;
-  };
-
-  /// \brief The struct representing a constant and its type.
-  struct Constant {
-    inline Constant(const Type &ty, Instruction &&value);
-    const Type &type;
-    Instruction constant;
-  };
-
   Header header; ///< SPIR-V module header.
   std::vector<spv::Capability> capabilities;
   std::vector<std::string> extensions;
@@ -288,6 +291,29 @@ void Function::addBasicBlock(std::unique_ptr<BasicBlock> block) {
   blocks.push_back(std::move(block));
 }
 
+ExtInstSet::ExtInstSet(uint32_t id, std::string name)
+    : resultId(id), setName(name) {}
+
+EntryPoint::EntryPoint(spv::ExecutionModel em, uint32_t id,
+                                    std::string name,
+                                    std::initializer_list<uint32_t> interface)
+    : executionModel(em), targetId(id), targetName(std::move(name)),
+      interfaces(std::move(interface)) {}
+
+DebugName::DebugName(uint32_t id, llvm::Optional<uint32_t> index,
+                                  std::string targetName)
+    : targetId(id), memberIndex(index), name(std::move(targetName)) {}
+
+DecorationIdPair::DecorationIdPair(const Decoration &decor,
+                                                uint32_t id)
+    : decoration(decor), targetId(id) {}
+
+TypeIdPair::TypeIdPair(const Type &ty, uint32_t id)
+    : type(ty), resultId(id) {}
+
+Constant::Constant(const Type &ty, Instruction &&value)
+    : type(ty), constant(std::move(value)) {}
+
 SPIRVModule::SPIRVModule()
     : addressingModel(llvm::None), memoryModel(llvm::None) {}
 
@@ -336,30 +362,7 @@ void SPIRVModule::addFunction(std::unique_ptr<Function> f) {
   functions.push_back(std::move(f));
 }
 
-SPIRVModule::ExtInstSet::ExtInstSet(uint32_t id, std::string name)
-    : resultId(id), setName(name) {}
-
-SPIRVModule::EntryPoint::EntryPoint(spv::ExecutionModel em, uint32_t id,
-                                    std::string name,
-                                    std::initializer_list<uint32_t> interface)
-    : executionModel(em), targetId(id), targetName(std::move(name)),
-      interfaces(std::move(interface)) {}
-
-SPIRVModule::DebugName::DebugName(uint32_t id, llvm::Optional<uint32_t> index,
-                                  std::string targetName)
-    : targetId(id), memberIndex(index), name(std::move(targetName)) {}
-
-SPIRVModule::DecorationIdPair::DecorationIdPair(const Decoration &decor,
-                                                uint32_t id)
-    : decoration(decor), targetId(id) {}
-
-SPIRVModule::TypeIdPair::TypeIdPair(const Type &ty, uint32_t id)
-    : type(ty), resultId(id) {}
-
-SPIRVModule::Constant::Constant(const Type &ty, Instruction &&value)
-    : type(ty), constant(std::move(value)) {}
-
 } // end namespace spirv
 } // end namespace clang
 
-#endif
+#endif

+ 84 - 12
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -21,16 +21,73 @@
 #include "llvm/Support/Path.h"
 #include "llvm/Support/raw_ostream.h"
 
+namespace {
+spv::ExecutionModel getSpirvShaderStageFromHlslProfile(const char *profile) {
+  // DXIL Models are:
+  // Profile (DXIL Model) : HLSL Shader Kind : SPIR-V Shader Kind
+  // vs_<version>         : Vertex Shader    : Vertex Shader
+  // hs_<version>         : Hull Shader      : Tassellation Control Shader
+  // ds_<version>         : Domain Shader    : Tessellation Evaluation Shader
+  // gs_<version>         : Geometry Shader  : Geometry Shader
+  // ps_<version>         : Pixel Shader     : Fragment Shader
+  // cs_<version>         : Compute Shader   : Compute Shader
+  switch (profile[0]) {
+  case 'v': return spv::ExecutionModel::Vertex;
+  case 'h': return spv::ExecutionModel::TessellationControl;
+  case 'd': return spv::ExecutionModel::TessellationEvaluation;
+  case 'g': return spv::ExecutionModel::Geometry;
+  case 'p': return spv::ExecutionModel::Fragment;
+  case 'c': return spv::ExecutionModel::GLCompute;
+  default:
+    assert(false && "Unknown HLSL Profile");
+    return spv::ExecutionModel::Fragment;
+  }
+}
+
+} // namespace
+
 namespace clang {
 namespace {
 
 class SPIRVEmitter : public ASTConsumer {
 public:
-  explicit SPIRVEmitter(raw_ostream *out)
-      : outStream(*out), theContext(), theBuilder(&theContext) {}
+  explicit SPIRVEmitter(CompilerInstance &ci)
+      : theCompilerInstance(ci), outStream(*ci.getOutStream()), theContext(),
+        theBuilder(&theContext) {}
+
+  void AddRequiredCapabilitiesForExecutionModel(spv::ExecutionModel em) {
+    if (em == spv::ExecutionModel::TessellationControl ||
+        em == spv::ExecutionModel::TessellationEvaluation) {
+      theBuilder.requireCapability(spv::Capability::Tessellation);
+      assert(false && "Tasselation Shaders are currently not supported.");
+    } else if (em == spv::ExecutionModel::Geometry) {
+      theBuilder.requireCapability(spv::Capability::Geometry);
+      assert(false && "Geometry Shaders are currently not supported.");
+    } else {
+      theBuilder.requireCapability(spv::Capability::Shader);
+    }
+  }
+
+  /// \brief Adds the execution mode for the given entry point based on the
+  /// execution model.
+  void AddExecutionModeForEntryPoint(spv::ExecutionModel execModel,
+                                     uint32_t entryPointId) {
+    if (execModel == spv::ExecutionModel::Fragment) {
+      // TODO: Implement the logic to determine the proper Execution Mode for
+      // fragment shaders. Currently using OriginUpperLeft as default.
+      theBuilder.addExecutionMode(entryPointId,
+                                  spv::ExecutionMode::OriginUpperLeft, {});
+    }
+    else {
+      // TODO: Implement logic for adding proper execution mode for other shader
+      // stages. Silently skipping for now.
+    }
+  }
 
   void HandleTranslationUnit(ASTContext &context) override {
-    theBuilder.requireCapability(spv::Capability::Shader);
+    const spv::ExecutionModel em = getSpirvShaderStageFromHlslProfile(
+        theCompilerInstance.getCodeGenOpts().HLSLProfile.c_str());
+    AddRequiredCapabilitiesForExecutionModel(em);
 
     // Addressing and memory model are required in a valid SPIR-V module.
     theBuilder.setAddressingModel(spv::AddressingModel::Logical);
@@ -55,24 +112,38 @@ public:
   }
 
   void doFunctionDecl(FunctionDecl *decl) {
-    const uint32_t funcType = translateFunctionType(decl);
+    std::vector<uint32_t> funcParamTypeIds;
+    const uint32_t funcType = translateFunctionType(decl, &funcParamTypeIds);
     const uint32_t retType = translateType(decl->getReturnType());
 
-    theBuilder.beginFunction(funcType, retType);
-    // TODO: handle function parameters
+    const uint32_t funcId =
+        theBuilder.beginFunction(funcType, retType, funcParamTypeIds);
     // TODO: handle function body
     const uint32_t entryLabel = theBuilder.bbCreate();
     theBuilder.bbReturn(entryLabel);
     theBuilder.endFunction();
+
+    // Add an entry point to the module if necessary
+    const std::string hlslEntryFn =
+        theCompilerInstance.getCodeGenOpts().HLSLEntryFunction;
+    if (hlslEntryFn == decl->getNameInfo().getAsString()) {
+      const spv::ExecutionModel em = getSpirvShaderStageFromHlslProfile(
+          theCompilerInstance.getCodeGenOpts().HLSLProfile.c_str());
+      // TODO: Pass correct input/output interfaces to addEntryPoint.
+      theBuilder.addEntryPoint(em, funcId, hlslEntryFn, {});
+
+      // OpExecutionMode declares an execution mode for an entry point.
+      AddExecutionModeForEntryPoint(em, funcId);
+    }
   }
 
-  uint32_t translateFunctionType(FunctionDecl *decl) {
+  uint32_t translateFunctionType(FunctionDecl *decl,
+                                 std::vector<uint32_t> *paramTypes) {
     const uint32_t retType = translateType(decl->getReturnType());
-    std::vector<uint32_t> paramTypes;
     for (auto *param : decl->params()) {
-      paramTypes.push_back(translateType(param->getType()));
+      paramTypes->push_back(translateType(param->getType()));
     }
-    return theBuilder.getFunctionType(retType, paramTypes);
+    return theBuilder.getFunctionType(retType, *paramTypes);
   }
 
   uint32_t translateType(QualType type) {
@@ -110,12 +181,13 @@ private:
   raw_ostream &outStream;
   spirv::SPIRVContext theContext;
   spirv::ModuleBuilder theBuilder;
+  CompilerInstance &theCompilerInstance;
 };
 
 } // namespace
 
 std::unique_ptr<ASTConsumer>
 EmitSPIRVAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
-  return llvm::make_unique<SPIRVEmitter>(CI.getOutStream());
+  return llvm::make_unique<SPIRVEmitter>(CI);
 }
-} // end namespace clang
+} // end namespace clang

+ 9 - 1
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -24,7 +24,9 @@ ModuleBuilder::ModuleBuilder(SPIRVContext *C)
   });
 }
 
-uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType) {
+uint32_t
+ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
+                             const std::vector<uint32_t> &paramTypeIds) {
   if (theFunction) {
     assert(false && "found nested function");
     return 0;
@@ -35,6 +37,12 @@ uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType) {
   theFunction = llvm::make_unique<Function>(
       returnType, fId, spv::FunctionControlMask::MaskNone, funcType);
 
+  // Any OpFunction must be immediately followed by one OpFunctionParameter
+  // instruction per each formal parameter of the function.
+  for (const auto &typeId : paramTypeIds) {
+    theFunction->addParameter(typeId, theContext.takeNextId());
+  }
+
   return fId;
 }
 

+ 11 - 11
tools/clang/lib/SPIRV/Structure.cpp

@@ -100,11 +100,21 @@ void Function::take(InstBuilder *builder) {
   clear();
 }
 
-SPIRVModule::Header::Header()
+Header::Header()
     : magicNumber(spv::MagicNumber), version(spv::Version),
       generator((kGeneratorNumber << 16) | kToolVersion), bound(0),
       reserved(0) {}
 
+void Header::collect(const WordConsumer &consumer) {
+  std::vector<uint32_t> words;
+  words.push_back(magicNumber);
+  words.push_back(version);
+  words.push_back(generator);
+  words.push_back(bound);
+  words.push_back(reserved);
+  consumer(std::move(words));
+}
+
 bool SPIRVModule::isEmpty() const {
   return header.bound == 0 && capabilities.empty() && extensions.empty() &&
          extInstSets.empty() && !addressingModel.hasValue() &&
@@ -193,15 +203,5 @@ void SPIRVModule::take(InstBuilder *builder) {
   clear();
 }
 
-void SPIRVModule::Header::collect(const WordConsumer &consumer) {
-  std::vector<uint32_t> words;
-  words.push_back(magicNumber);
-  words.push_back(version);
-  words.push_back(generator);
-  words.push_back(bound);
-  words.push_back(reserved);
-  consumer(std::move(words));
-}
-
 } // end namespace spirv
 } // end namespace clang

+ 2 - 6
tools/clang/test/CodeGenSPIRV/empty-void-main.hlsl2spv

@@ -4,12 +4,6 @@ void main()
 
 }
 
-
-// TODO:
-// OpEntryPoint Fragment %main "main"
-// OpExecutionMode %main OriginUpperLeft
-
-
 // CHECK-WHOLE-SPIR-V:
 // ; SPIR-V
 // ; Version: 1.0
@@ -18,6 +12,8 @@ void main()
 // ; Schema: 0
 // OpCapability Shader
 // OpMemoryModel Logical GLSL450
+// OpEntryPoint Fragment %3 "main"
+// OpExecutionMode %3 OriginUpperLeft
 // %void = OpTypeVoid
 // %2 = OpTypeFunction %void
 // %3 = OpFunction %void None %2

+ 6 - 5
tools/clang/test/CodeGenSPIRV/passthru-ps.hlsl2spv

@@ -6,10 +6,8 @@ float4 main(float4 input: COLOR): SV_TARGET
 
 
 // TODO:
-// OpEntryPoint Fragment %main "main"
-// OpExecutionMode %main OriginUpperLeft
+// Input/Output interfaces are missing from OpEntryPoint
 // Semantics
-// Function parameter
 // Function return value
 
 
@@ -17,14 +15,17 @@ float4 main(float4 input: COLOR): SV_TARGET
 // ; SPIR-V
 // ; Version: 1.0
 // ; Generator: Google spiregg; 0
-// ; Bound: 6
+// ; Bound: 7
 // ; Schema: 0
 // OpCapability Shader
 // OpMemoryModel Logical GLSL450
+// OpEntryPoint Fragment %4 "main"
+// OpExecutionMode %4 OriginUpperLeft
 // %float = OpTypeFloat 32
 // %v4float = OpTypeVector %float 4
 // %3 = OpTypeFunction %v4float %v4float
 // %4 = OpFunction %v4float None %3
-// %5 = OpLabel
+// %5 = OpFunctionParameter %v4float
+// %6 = OpLabel
 // OpReturn
 // OpFunctionEnd

+ 49 - 0
tools/clang/test/CodeGenSPIRV/passthru-vs.hlsl2spv

@@ -0,0 +1,49 @@
+// Run: %dxc -T vs_6_0 -E VSmain
+
+/*
+// This is the proper way to do it, but our code doesn't handle PSInput
+// as return type for now.
+
+struct PSInput {
+  float4 position : POSITION;
+  float4 color : COLOR;
+};
+
+PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
+  PSInput result;
+  result.position = position;
+  result.color = color;
+  return result;
+}
+*/
+
+float4 VSmain(float4 position: POSITION, float4 color: COLOR) {
+  return color;
+}
+
+// TODO:
+// Proper representation for the structure.
+// Input/Output interfaces for OpEntryPoint
+// Proper logic to determine ExecutionMode
+// Semantics
+// Function return value
+
+
+// CHECK-WHOLE-SPIR-V:
+// ; SPIR-V
+// ; Version: 1.0
+// ; Generator: Google spiregg; 0
+// ; Bound: 8
+// ; Schema: 0
+// OpCapability Shader
+// OpMemoryModel Logical GLSL450
+// OpEntryPoint Vertex %4 "VSmain"
+// %float = OpTypeFloat 32
+// %v4float = OpTypeVector %float 4
+// %3 = OpTypeFunction %v4float %v4float %v4float
+// %4 = OpFunction %v4float None %3
+// %5 = OpFunctionParameter %v4float
+// %6 = OpFunctionParameter %v4float
+// %7 = OpLabel
+// OpReturn
+// OpFunctionEnd

+ 6 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -23,3 +23,9 @@ TEST_F(WholeFileTest, PassThruPixelShader) {
                    /*generateHeader*/ true,
                    /*runValidation*/ false);
 }
+
+TEST_F(WholeFileTest, PassThruVertexShader) {
+  runWholeFileTest("passthru-vs.hlsl2spv",
+                   /*generateHeader*/ true,
+                   /*runValidation*/ false);
+}