Browse Source

[SPIR-V] Add noinline support for SPIR-V generation (#3163)

This PR will add `DontInline` function control flag to the OpFunction in generated SPIR-V if the function has `[noinline]` in HLSL shader.

This is the first step of work for #3158.
After the `DontInline` flag is added, spirv-opt needs an option to control whether it should honor the flag or not (it ignores the flag currently), and the option should be passed down by DXC.
Junda Liu 4 years ago
parent
commit
e39d1397ca

+ 3 - 1
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -68,7 +68,8 @@ public:
   /// \brief Creates a SpirvFunction object with the given information and adds
   /// it to list of all discovered functions in the SpirvModule.
   SpirvFunction *createSpirvFunction(QualType returnType, SourceLocation,
-                                     llvm::StringRef name, bool isPrecise);
+                                     llvm::StringRef name, bool isPrecise,
+                                     bool isNoInline = false);
 
   /// \brief Begins building a SPIR-V function by allocating a SpirvFunction
   /// object. Returns the pointer for the function on success. Returns nullptr
@@ -78,6 +79,7 @@ public:
   SpirvFunction *beginFunction(QualType returnType, SourceLocation,
                                llvm::StringRef name = "",
                                bool isPrecise = false,
+                               bool isNoInline = false,
                                SpirvFunction *func = nullptr);
 
   /// \brief Creates and registers a function parameter of the given pointer

+ 7 - 1
tools/clang/include/clang/SPIRV/SpirvFunction.h

@@ -25,7 +25,8 @@ class SpirvVisitor;
 class SpirvFunction {
 public:
   SpirvFunction(QualType astReturnType, SourceLocation,
-                llvm::StringRef name = "", bool precise = false);
+                llvm::StringRef name = "", bool precise = false,
+                bool noInline = false);
 
   ~SpirvFunction();
 
@@ -70,8 +71,12 @@ public:
 
   // Store that the return value is precise.
   void setPrecise(bool p = true) { precise = p; }
+  // Store that the function should not be inlined.
+  void setNoInline(bool n = true) { noInline = n; }
   // Returns whether the return value is precise.
   bool isPrecise() const { return precise; }
+  // Returns whether the function is marked as no inline
+  bool isNoInline() const { return noInline; }
 
   void setSourceLocation(SourceLocation loc) { functionLoc = loc; }
   SourceLocation getSourceLocation() const { return functionLoc; }
@@ -117,6 +122,7 @@ private:
   SpirvType *fnType;      ///< The SPIR-V function type
   bool relaxedPrecision;  ///< Whether the return type is at relaxed precision
   bool precise;           ///< Whether the return value is 'precise'
+  bool noInline;          ///< The function is marked as no inline
 
   /// Legalization-specific code
   ///

+ 3 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -1201,13 +1201,15 @@ SpirvFunction *DeclResultIdMapper::getOrRegisterFn(const FunctionDecl *fn) {
   (void)getTypeAndCreateCounterForPotentialAliasVar(fn, &isAlias);
 
   const bool isPrecise = fn->hasAttr<HLSLPreciseAttr>();
+  const bool isNoInline = fn->hasAttr<NoInlineAttr>();
   // Note: we do not need to worry about function parameter types at this point
   // as this is used when function declarations are seen. When function
   // definition is seen, the parameter types will be set properly and take into
   // account whether the function is a member function of a class/struct (in
   // which case a 'this' parameter is added at the beginnig).
   SpirvFunction *spirvFunction = spvBuilder.createSpirvFunction(
-      fn->getReturnType(), fn->getLocation(), fn->getName(), isPrecise);
+      fn->getReturnType(), fn->getLocation(), fn->getName(), isPrecise,
+      isNoInline);
 
   // No need to dereference to get the pointer. Function returns that are
   // stand-alone aliases are already pointers to values. All other cases should

+ 2 - 1
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -376,7 +376,8 @@ bool EmitVisitor::visit(SpirvFunction *fn, Phase phase) {
     initInstruction(spv::Op::OpFunction, fn->getSourceLocation());
     curInst.push_back(returnTypeId);
     curInst.push_back(getOrAssignResultId<SpirvFunction>(fn));
-    curInst.push_back(
+    curInst.push_back(fn->isNoInline() ?
+        static_cast<uint32_t>(spv::FunctionControlMask::DontInline) :
         static_cast<uint32_t>(spv::FunctionControlMask::MaskNone));
     curInst.push_back(functionTypeId);
     finalizeInstruction(&mainBinary);

+ 8 - 3
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -32,8 +32,10 @@ SpirvBuilder::SpirvBuilder(ASTContext &ac, SpirvContext &ctx,
 SpirvFunction *SpirvBuilder::createSpirvFunction(QualType returnType,
                                                  SourceLocation loc,
                                                  llvm::StringRef name,
-                                                 bool isPrecise) {
-  auto *fn = new (context) SpirvFunction(returnType, loc, name, isPrecise);
+                                                 bool isPrecise,
+                                                 bool isNoInline) {
+  auto *fn = new (context) SpirvFunction(returnType, loc, name, isPrecise,
+                                         isNoInline);
   mod->addFunction(fn);
   return fn;
 }
@@ -42,6 +44,7 @@ SpirvFunction *SpirvBuilder::beginFunction(QualType returnType,
                                            SourceLocation loc,
                                            llvm::StringRef funcName,
                                            bool isPrecise,
+                                           bool isNoInline,
                                            SpirvFunction *func) {
   assert(!function && "found nested function");
   if (func) {
@@ -50,8 +53,10 @@ SpirvFunction *SpirvBuilder::beginFunction(QualType returnType,
     function->setSourceLocation(loc);
     function->setFunctionName(funcName);
     function->setPrecise(isPrecise);
+    function->setNoInline(isNoInline);
   } else {
-    function = createSpirvFunction(returnType, loc, funcName, isPrecise);
+    function = createSpirvFunction(returnType, loc, funcName, isPrecise,
+                                   isNoInline);
   }
 
   return function;

+ 2 - 1
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -1016,7 +1016,8 @@ void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) {
       declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(decl);
 
   spvBuilder.beginFunction(retType, decl->getLocStart(), funcName,
-                           decl->hasAttr<HLSLPreciseAttr>(), func);
+                           decl->hasAttr<HLSLPreciseAttr>(),
+                           decl->hasAttr<NoInlineAttr>(), func);
 
   auto loc = decl->getLocStart();
   RichDebugInfo *info = nullptr;

+ 5 - 3
tools/clang/lib/SPIRV/SpirvFunction.cpp

@@ -16,11 +16,13 @@ namespace clang {
 namespace spirv {
 
 SpirvFunction::SpirvFunction(QualType returnType, SourceLocation loc,
-                             llvm::StringRef name, bool isPrecise)
+                             llvm::StringRef name, bool isPrecise,
+                             bool isNoInline)
     : functionId(0), astReturnType(returnType), returnType(nullptr),
       fnType(nullptr), relaxedPrecision(false), precise(isPrecise),
-      containsAlias(false), rvalue(false), functionLoc(loc), functionName(name),
-      isWrapperOfEntry(false), debugScope(nullptr) {}
+      noInline(isNoInline), containsAlias(false), rvalue(false),
+      functionLoc(loc), functionName(name), isWrapperOfEntry(false),
+      debugScope(nullptr) {}
 
 SpirvFunction::~SpirvFunction() {
   for (auto *param : parameters)

+ 14 - 0
tools/clang/test/CodeGenSPIRV/fn.noinline.hlsl

@@ -0,0 +1,14 @@
+// Run: %dxc -T ps_6_0 -E main
+
+[noinline]
+float4 foo()
+{
+    return 0;
+}
+
+void main()
+{
+    foo();
+}
+
+// CHECK:  %foo = OpFunction %v4float DontInline {{%\d+}}

+ 2 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -586,6 +586,8 @@ TEST_F(FileTest, FunctionInCTBuffer) {
   runFileTest("fn.ctbuffer.hlsl");
 }
 
+TEST_F(FileTest, FunctionNoInline) { runFileTest("fn.noinline.hlsl"); }
+
 // For OO features
 TEST_F(FileTest, StructMethodCall) {
   setBeforeHLSLLegalization();