Browse Source

[SPIR-V] Add option to rename SPIR-V entry point (#4390)

Allow renaming the SPIR-V entry point name from the default HLSL entry
point name with `-fspv-entrypoint-name`.

Fixes #2972
Fixes #4356
Natalie Chouinard 3 năm trước cách đây
mục cha
commit
787245f28f

+ 2 - 0
docs/SPIR-V.rst

@@ -3974,6 +3974,8 @@ codegen for Vulkan:
   SPIR-V backend. Also note that this requires the optimizer to be able to
   SPIR-V backend. Also note that this requires the optimizer to be able to
   resolve all array accesses with constant indeces. Therefore, all loops using
   resolve all array accesses with constant indeces. Therefore, all loops using
   the resource arrays must be marked with ``[unroll]``.
   the resource arrays must be marked with ``[unroll]``.
+- ``-fspv-entrypoint-name=<name>``: Specify the SPIR-V entry point name. Defaults
+  to the HLSL entry point name.
 - ``-Wno-vk-ignored-features``: Does not emit warnings on ignored features
 - ``-Wno-vk-ignored-features``: Does not emit warnings on ignored features
   resulting from no Vulkan support, e.g., cbuffer member initializer.
   resulting from no Vulkan support, e.g., cbuffer member initializer.
 
 

+ 2 - 0
include/dxc/Support/HLSLOptions.td

@@ -357,6 +357,8 @@ def fspv_flatten_resource_arrays: Flag<["-"], "fspv-flatten-resource-arrays">, G
   HelpText<"Flatten arrays of resources so each array element takes one binding number">;
   HelpText<"Flatten arrays of resources so each array element takes one binding number">;
 def fspv_reduce_load_size: Flag<["-"], "fspv-reduce-load-size">, Group<spirv_Group>, Flags<[CoreOption, DriverOption]>,
 def fspv_reduce_load_size: Flag<["-"], "fspv-reduce-load-size">, Group<spirv_Group>, Flags<[CoreOption, DriverOption]>,
   HelpText<"Replaces loads of composite objects to reduce memory pressure for the loads">;
   HelpText<"Replaces loads of composite objects to reduce memory pressure for the loads">;
+def fspv_entrypoint_name_EQ : Joined<["-"], "fspv-entrypoint-name=">, Group<spirv_Group>, Flags<[CoreOption, DriverOption]>,
+  HelpText<"Specify the SPIR-V entry point name. Defaults to the HLSL entry point name.">;
 def fvk_auto_shift_bindings: Flag<["-"], "fvk-auto-shift-bindings">, Group<spirv_Group>, Flags<[CoreOption, DriverOption]>,
 def fvk_auto_shift_bindings: Flag<["-"], "fvk-auto-shift-bindings">, Group<spirv_Group>, Flags<[CoreOption, DriverOption]>,
   HelpText<"Apply fvk-*-shift to resources without an explicit register assignment.">;
   HelpText<"Apply fvk-*-shift to resources without an explicit register assignment.">;
 def Wno_vk_ignored_features : Joined<["-"], "Wno-vk-ignored-features">, Group<spirv_Group>, Flags<[CoreOption, DriverOption, HelpHidden]>,
 def Wno_vk_ignored_features : Joined<["-"], "Wno-vk-ignored-features">, Group<spirv_Group>, Flags<[CoreOption, DriverOption, HelpHidden]>,

+ 1 - 0
include/dxc/Support/SPIRVOptions.h

@@ -84,6 +84,7 @@ struct SpirvCodeGenOptions {
   llvm::SmallVector<llvm::StringRef, 4> optConfig;
   llvm::SmallVector<llvm::StringRef, 4> optConfig;
   std::vector<std::string> bindRegister;
   std::vector<std::string> bindRegister;
   std::vector<std::string> bindGlobals;
   std::vector<std::string> bindGlobals;
+  std::string entrypointName;
 
 
   bool signaturePacking; ///< Whether signature packing is enabled or not
   bool signaturePacking; ///< Whether signature packing is enabled or not
 
 

+ 3 - 0
lib/DxcSupport/HLSLOptions.cpp

@@ -1064,6 +1064,9 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
     }
     }
   }
   }
 
 
+  opts.SpirvOptions.entrypointName =
+      Args.getLastArgValue(OPT_fspv_entrypoint_name_EQ);
+
 #else
 #else
   if (Args.hasFlag(OPT_spirv, OPT_INVALID, false) ||
   if (Args.hasFlag(OPT_spirv, OPT_INVALID, false) ||
       Args.hasFlag(OPT_fvk_invert_y, OPT_INVALID, false) ||
       Args.hasFlag(OPT_fvk_invert_y, OPT_INVALID, false) ||

+ 15 - 4
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -549,8 +549,8 @@ SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
     : theCompilerInstance(ci), astContext(ci.getASTContext()),
     : theCompilerInstance(ci), astContext(ci.getASTContext()),
       diags(ci.getDiagnostics()),
       diags(ci.getDiagnostics()),
       spirvOptions(ci.getCodeGenOpts().SpirvOptions),
       spirvOptions(ci.getCodeGenOpts().SpirvOptions),
-      entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction), spvContext(),
-      featureManager(diags, spirvOptions),
+      hlslEntryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
+      spvContext(), featureManager(diags, spirvOptions),
       spvBuilder(astContext, spvContext, spirvOptions, featureManager),
       spvBuilder(astContext, spvContext, spirvOptions, featureManager),
       declIdMapper(astContext, spvContext, spvBuilder, *this, featureManager,
       declIdMapper(astContext, spvContext, spvBuilder, *this, featureManager,
                    spirvOptions),
                    spirvOptions),
@@ -686,6 +686,17 @@ SpirvEmitter::getInterfacesForEntryPoint(SpirvFunction *entryPoint) {
   return interfacesInVector;
   return interfacesInVector;
 }
 }
 
 
+llvm::StringRef SpirvEmitter::getEntryPointName(const FunctionInfo *entryInfo) {
+  llvm::StringRef entrypointName = entryInfo->funcDecl->getName();
+  // If this is the -E HLSL entrypoint and -fspv-entrypoint-name was set,
+  // rename the SPIR-V entrypoint.
+  if (entrypointName == hlslEntryFunctionName &&
+      !spirvOptions.entrypointName.empty()) {
+    return spirvOptions.entrypointName;
+  }
+  return entrypointName;
+}
+
 void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
 void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
   // Stop translating if there are errors in previous compilation stages.
   // Stop translating if there are errors in previous compilation stages.
   if (context.getDiagnostics().hasErrorOccurred())
   if (context.getDiagnostics().hasErrorOccurred())
@@ -709,7 +720,7 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
                                  funcDecl, /*isEntryFunction*/ false);
                                  funcDecl, /*isEntryFunction*/ false);
         }
         }
       } else {
       } else {
-        if (funcDecl->getName() == entryFunctionName) {
+        if (funcDecl->getName() == hlslEntryFunctionName) {
           addFunctionToWorkQueue(spvContext.getCurrentShaderModelKind(),
           addFunctionToWorkQueue(spvContext.getCurrentShaderModelKind(),
                                  funcDecl, /*isEntryFunction*/ true);
                                  funcDecl, /*isEntryFunction*/ true);
           numEntryPoints++;
           numEntryPoints++;
@@ -750,7 +761,7 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
     assert(entryInfo->isEntryFunction);
     assert(entryInfo->isEntryFunction);
     spvBuilder.addEntryPoint(
     spvBuilder.addEntryPoint(
         SpirvUtils::getSpirvShaderStage(entryInfo->shaderModelKind),
         SpirvUtils::getSpirvShaderStage(entryInfo->shaderModelKind),
-        entryInfo->entryFunction, entryInfo->funcDecl->getName(),
+        entryInfo->entryFunction, getEntryPointName(entryInfo),
         getInterfacesForEntryPoint(entryInfo->entryFunction));
         getInterfacesForEntryPoint(entryInfo->entryFunction));
   }
   }
 
 

+ 4 - 1
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -1164,7 +1164,7 @@ private:
 
 
   /// \brief Entry function name, derived from the command line
   /// \brief Entry function name, derived from the command line
   /// and should be const.
   /// and should be const.
-  const llvm::StringRef entryFunctionName;
+  const llvm::StringRef hlslEntryFunctionName;
 
 
   /// \brief Structure to maintain record of all entry functions and any
   /// \brief Structure to maintain record of all entry functions and any
   /// reachable functions.
   /// reachable functions.
@@ -1194,6 +1194,9 @@ private:
   /// A queue of FunctionInfo reachable from all the entry functions.
   /// A queue of FunctionInfo reachable from all the entry functions.
   std::vector<const FunctionInfo *> workQueue;
   std::vector<const FunctionInfo *> workQueue;
 
 
+  /// Get SPIR-V entrypoint name for the given FunctionInfo.
+  llvm::StringRef getEntryPointName(const FunctionInfo *entryInfo);
+
   /// <result-id> for the entry function. Initially it is zero and will be reset
   /// <result-id> for the entry function. Initially it is zero and will be reset
   /// when starting to translate the entry function.
   /// when starting to translate the entry function.
   SpirvFunction *entryFunction;
   SpirvFunction *entryFunction;

+ 7 - 0
tools/clang/test/CodeGenSPIRV/fspv-entrypoint-name.hlsl

@@ -0,0 +1,7 @@
+// RUN: %dxc -T ps_6_0 -E PSMain -fspv-entrypoint-name=main
+
+// CHECK: OpEntryPoint Fragment %PSMain "main" %in_var_COLOR %out_var_SV_TARGET
+float4 PSMain(float4 color : COLOR) : SV_TARGET
+{
+    return color;
+}

+ 2 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -3111,4 +3111,6 @@ float4 PSMain(float4 color : COLOR) : SV_TARGET { return color; }
   runCodeTest(code);
   runCodeTest(code);
 }
 }
 
 
+TEST_F(FileTest, RenameEntrypoint) { runFileTest("fspv-entrypoint-name.hlsl"); }
+
 } // namespace
 } // namespace