Explorar o código

[SPIR-V] Add option to rename SPIR-V entry point (#4390)

Allow renaming the SPIR-V entry point name from the default HLSL entry
point name with `-fspv-entrypoint-name`.

Fixes #2972
Fixes #4356
Natalie Chouinard %!s(int64=3) %!d(string=hai) anos
pai
achega
787245f28f

+ 2 - 0
docs/SPIR-V.rst

@@ -3974,6 +3974,8 @@ codegen for Vulkan:
   SPIR-V backend. Also note that this requires the optimizer to be able to
   resolve all array accesses with constant indeces. Therefore, all loops using
   the resource arrays must be marked with ``[unroll]``.
+- ``-fspv-entrypoint-name=<name>``: Specify the SPIR-V entry point name. Defaults
+  to the HLSL entry point name.
 - ``-Wno-vk-ignored-features``: Does not emit warnings on ignored features
   resulting from no Vulkan support, e.g., cbuffer member initializer.
 

+ 2 - 0
include/dxc/Support/HLSLOptions.td

@@ -357,6 +357,8 @@ def fspv_flatten_resource_arrays: Flag<["-"], "fspv-flatten-resource-arrays">, G
   HelpText<"Flatten arrays of resources so each array element takes one binding number">;
 def fspv_reduce_load_size: Flag<["-"], "fspv-reduce-load-size">, Group<spirv_Group>, Flags<[CoreOption, DriverOption]>,
   HelpText<"Replaces loads of composite objects to reduce memory pressure for the loads">;
+def fspv_entrypoint_name_EQ : Joined<["-"], "fspv-entrypoint-name=">, Group<spirv_Group>, Flags<[CoreOption, DriverOption]>,
+  HelpText<"Specify the SPIR-V entry point name. Defaults to the HLSL entry point name.">;
 def fvk_auto_shift_bindings: Flag<["-"], "fvk-auto-shift-bindings">, Group<spirv_Group>, Flags<[CoreOption, DriverOption]>,
   HelpText<"Apply fvk-*-shift to resources without an explicit register assignment.">;
 def Wno_vk_ignored_features : Joined<["-"], "Wno-vk-ignored-features">, Group<spirv_Group>, Flags<[CoreOption, DriverOption, HelpHidden]>,

+ 1 - 0
include/dxc/Support/SPIRVOptions.h

@@ -84,6 +84,7 @@ struct SpirvCodeGenOptions {
   llvm::SmallVector<llvm::StringRef, 4> optConfig;
   std::vector<std::string> bindRegister;
   std::vector<std::string> bindGlobals;
+  std::string entrypointName;
 
   bool signaturePacking; ///< Whether signature packing is enabled or not
 

+ 3 - 0
lib/DxcSupport/HLSLOptions.cpp

@@ -1064,6 +1064,9 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
     }
   }
 
+  opts.SpirvOptions.entrypointName =
+      Args.getLastArgValue(OPT_fspv_entrypoint_name_EQ);
+
 #else
   if (Args.hasFlag(OPT_spirv, OPT_INVALID, false) ||
       Args.hasFlag(OPT_fvk_invert_y, OPT_INVALID, false) ||

+ 15 - 4
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -549,8 +549,8 @@ SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
     : theCompilerInstance(ci), astContext(ci.getASTContext()),
       diags(ci.getDiagnostics()),
       spirvOptions(ci.getCodeGenOpts().SpirvOptions),
-      entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction), spvContext(),
-      featureManager(diags, spirvOptions),
+      hlslEntryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
+      spvContext(), featureManager(diags, spirvOptions),
       spvBuilder(astContext, spvContext, spirvOptions, featureManager),
       declIdMapper(astContext, spvContext, spvBuilder, *this, featureManager,
                    spirvOptions),
@@ -686,6 +686,17 @@ SpirvEmitter::getInterfacesForEntryPoint(SpirvFunction *entryPoint) {
   return interfacesInVector;
 }
 
+llvm::StringRef SpirvEmitter::getEntryPointName(const FunctionInfo *entryInfo) {
+  llvm::StringRef entrypointName = entryInfo->funcDecl->getName();
+  // If this is the -E HLSL entrypoint and -fspv-entrypoint-name was set,
+  // rename the SPIR-V entrypoint.
+  if (entrypointName == hlslEntryFunctionName &&
+      !spirvOptions.entrypointName.empty()) {
+    return spirvOptions.entrypointName;
+  }
+  return entrypointName;
+}
+
 void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
   // Stop translating if there are errors in previous compilation stages.
   if (context.getDiagnostics().hasErrorOccurred())
@@ -709,7 +720,7 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
                                  funcDecl, /*isEntryFunction*/ false);
         }
       } else {
-        if (funcDecl->getName() == entryFunctionName) {
+        if (funcDecl->getName() == hlslEntryFunctionName) {
           addFunctionToWorkQueue(spvContext.getCurrentShaderModelKind(),
                                  funcDecl, /*isEntryFunction*/ true);
           numEntryPoints++;
@@ -750,7 +761,7 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
     assert(entryInfo->isEntryFunction);
     spvBuilder.addEntryPoint(
         SpirvUtils::getSpirvShaderStage(entryInfo->shaderModelKind),
-        entryInfo->entryFunction, entryInfo->funcDecl->getName(),
+        entryInfo->entryFunction, getEntryPointName(entryInfo),
         getInterfacesForEntryPoint(entryInfo->entryFunction));
   }
 

+ 4 - 1
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -1164,7 +1164,7 @@ private:
 
   /// \brief Entry function name, derived from the command line
   /// and should be const.
-  const llvm::StringRef entryFunctionName;
+  const llvm::StringRef hlslEntryFunctionName;
 
   /// \brief Structure to maintain record of all entry functions and any
   /// reachable functions.
@@ -1194,6 +1194,9 @@ private:
   /// A queue of FunctionInfo reachable from all the entry functions.
   std::vector<const FunctionInfo *> workQueue;
 
+  /// Get SPIR-V entrypoint name for the given FunctionInfo.
+  llvm::StringRef getEntryPointName(const FunctionInfo *entryInfo);
+
   /// <result-id> for the entry function. Initially it is zero and will be reset
   /// when starting to translate the entry function.
   SpirvFunction *entryFunction;

+ 7 - 0
tools/clang/test/CodeGenSPIRV/fspv-entrypoint-name.hlsl

@@ -0,0 +1,7 @@
+// RUN: %dxc -T ps_6_0 -E PSMain -fspv-entrypoint-name=main
+
+// CHECK: OpEntryPoint Fragment %PSMain "main" %in_var_COLOR %out_var_SV_TARGET
+float4 PSMain(float4 color : COLOR) : SV_TARGET
+{
+    return color;
+}

+ 2 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -3111,4 +3111,6 @@ float4 PSMain(float4 color : COLOR) : SV_TARGET { return color; }
   runCodeTest(code);
 }
 
+TEST_F(FileTest, RenameEntrypoint) { runFileTest("fspv-entrypoint-name.hlsl"); }
+
 } // namespace