Browse Source

Fix crash when remove unused globals in rewriter and support remove types. (#2933)

Xiang Li 5 years ago
parent
commit
7e780aef6f

+ 1 - 0
include/dxc/Support/HLSLOptions.h

@@ -94,6 +94,7 @@ struct RewriterOpts {
   bool KeepUserMacro = false;               // OPT_rw_keep_user_macro
   bool ExtractEntryUniforms = false;        // OPT_rw_extract_entry_uniforms
   bool RemoveUnusedGlobals = false;         // OPT_rw_remove_unused_globals
+  bool RemoveUnusedFunctions = false;         // OPT_rw_remove_unused_functions
 };
 
 /// Use this class to capture all options.

+ 2 - 1
include/dxc/Support/HLSLOptions.td

@@ -451,6 +451,7 @@ def rw_extract_entry_uniforms : Flag<["-", "/"], "extract-entry-uniforms">, Grou
   HelpText<"Move uniform parameters from entry point to global scope">;
 def rw_remove_unused_globals : Flag<["-", "/"], "remove-unused-globals">, Group<hlslrewrite_Group>, Flags<[RewriteOption]>,
   HelpText<"Remove unused static globals and functions">;
-
+def rw_remove_unused_functions : Flag<["-", "/"], "remove-unused-functions">, Group<hlslrewrite_Group>, Flags<[RewriteOption]>,
+  HelpText<"Remove unused functions and types">;
 // Also removed: compress, decompress, /Gch (child effect), /Gpp (partial precision)
 // /Op - no support for preshaders.

+ 4 - 2
lib/DxcSupport/HLSLOptions.cpp

@@ -894,10 +894,12 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
     opts.RWOpt.KeepUserMacro = Args.hasFlag(OPT_rw_keep_user_macro, OPT_INVALID, false);
     opts.RWOpt.ExtractEntryUniforms = Args.hasFlag(OPT_rw_extract_entry_uniforms, OPT_INVALID, false);
     opts.RWOpt.RemoveUnusedGlobals = Args.hasFlag(OPT_rw_remove_unused_globals, OPT_INVALID, false);
+    opts.RWOpt.RemoveUnusedFunctions = Args.hasFlag(OPT_rw_remove_unused_functions, OPT_INVALID, false);
 
     if (opts.EntryPoint.empty() &&
-        (opts.RWOpt.RemoveUnusedGlobals || opts.RWOpt.ExtractEntryUniforms)) {
-      errors << "-rw-remove-unused-globals and -rw-extract-entry-uniforms requires entry point (-E) to be specified.";
+        (opts.RWOpt.RemoveUnusedGlobals || opts.RWOpt.ExtractEntryUniforms ||
+         opts.RWOpt.RemoveUnusedFunctions)) {
+      errors << "-remove-unused-globals, -remove-unused-functions and -extract-entry-uniforms requires entry point (-E) to be specified.";
       return 1;
     }
   }

+ 21 - 0
tools/clang/test/HLSLFileCheck/rewriter/ConstantBuffer.hlsl

@@ -0,0 +1,21 @@
+// RUN: %dxr -E main -remove-unused-globals %s | FileCheck %s
+// RUN: %dxr -E main -remove-unused-functions %s | FileCheck %s -check-prefix=KEEP_GLOBAL
+
+// CHECK-NOT:struct
+// CHECK-NOT:ConstantBuffer.
+// CHECK:float main
+
+// KEEP_GLOBAL:struct
+// KEEP_GLOBAL:ConstantBuffer.
+// KEEP_GLOBAL:float main
+
+
+struct ST
+{
+ uint t;
+};
+ConstantBuffer<ST> cbv;
+
+float main() : SV_Target {
+  return 1;
+}

+ 16 - 0
tools/clang/test/HLSLFileCheck/rewriter/overload.hlsl

@@ -0,0 +1,16 @@
+// RUN: %dxr -E main -remove-unused-globals %s | FileCheck %s
+
+// CHECK-NOT:foo
+// CHECK:float main
+
+float foo(float a) { return a; }
+
+float2 foo(float2 a) { return a; }
+
+float foo(float a);
+
+float a;
+
+float main() : SV_Target {
+  return 1;
+}

+ 14 - 0
tools/clang/test/HLSLFileCheck/rewriter/overload2.hlsl

@@ -0,0 +1,14 @@
+// RUN: %dxr -E main -remove-unused-globals %s | FileCheck %s
+
+// CHECK:foo
+// CHECK-NOT:foo(float2
+// CHECK:float main
+// CHECK:foo(1.2)
+
+float foo(float a) { return a; }
+
+float2 foo(float2 a) { return a; }
+
+float main() : SV_Target {
+  return foo(1.2);
+}

+ 16 - 0
tools/clang/test/HLSLFileCheck/rewriter/overload3.hlsl

@@ -0,0 +1,16 @@
+// RUN: %dxr -E main -remove-unused-globals %s | FileCheck %s
+
+// CHECK:foo
+// CHECK-NOT:foo
+// CHECK:float main
+// CHECK:foo(1.2)
+
+
+float foo(float a) { return a; }
+
+float2 foo(float2 a) { return a; }
+float foo(float a);
+float2 foo(float2 a);
+float main() : SV_Target {
+  return foo(1.2);
+}

+ 95 - 13
tools/clang/tools/libclang/dxcrewriteunused.cpp

@@ -67,31 +67,60 @@ public:
     m_sema = nullptr;
   }
 };
+static FunctionDecl *getFunctionWithBody(FunctionDecl *F) {
+  if (!F)
+    return nullptr;
+  if (F->doesThisDeclarationHaveABody())
+    return F;
+  F = F->getFirstDecl();
+  for (auto &&Candidate : F->redecls()) {
+    if (Candidate->doesThisDeclarationHaveABody()) {
+      return Candidate;
+    }
+  }
+  return nullptr;
+}
 
 class VarReferenceVisitor : public RecursiveASTVisitor<VarReferenceVisitor> {
 private:
   SmallPtrSetImpl<VarDecl*>& m_unusedGlobals;
   SmallPtrSetImpl<FunctionDecl*>& m_visitedFunctions;
   SmallVectorImpl<FunctionDecl*>& m_pendingFunctions;
+  SmallPtrSetImpl<TypeDecl *> &m_visitedTypes;
+
 public:
   VarReferenceVisitor(
     SmallPtrSetImpl<VarDecl*>& unusedGlobals,
     SmallPtrSetImpl<FunctionDecl*>& visitedFunctions,
-    SmallVectorImpl<FunctionDecl*>& pendingFunctions) :
+    SmallVectorImpl<FunctionDecl*>& pendingFunctions,
+    SmallPtrSetImpl<TypeDecl *> &types)  :
     m_unusedGlobals(unusedGlobals),
     m_visitedFunctions(visitedFunctions),
-    m_pendingFunctions(pendingFunctions) {
+    m_pendingFunctions(pendingFunctions),
+    m_visitedTypes(types) {
   }
 
   bool VisitDeclRefExpr(DeclRefExpr* ref) {
     ValueDecl* valueDecl = ref->getDecl();
     if (FunctionDecl* fnDecl = dyn_cast_or_null<FunctionDecl>(valueDecl)) {
-      if (!m_visitedFunctions.count(fnDecl)) {
-        m_pendingFunctions.push_back(fnDecl);
+      FunctionDecl *fnDeclWithbody = getFunctionWithBody(fnDecl);
+      if (fnDeclWithbody) {
+        if (!m_visitedFunctions.count(fnDeclWithbody)) {
+          m_pendingFunctions.push_back(fnDeclWithbody);
+        }
+      }
+      if (fnDeclWithbody != fnDecl) {
+        // In case fnDecl is only a decl, setDecl to fnDeclWithbody.
+        // fnDecl will be removed.
+        ref->setDecl(fnDeclWithbody);
       }
     }
     else if (VarDecl* varDecl = dyn_cast_or_null<VarDecl>(valueDecl)) {
       m_unusedGlobals.erase(varDecl);
+      if (TagDecl *tagDecl = varDecl->getType()->getAsTagDecl()) {
+        m_visitedTypes.insert(tagDecl);
+      }
+      varDecl->getType();
     }
     return true;
   }
@@ -102,6 +131,9 @@ public:
         m_pendingFunctions.push_back(fnDecl);
       }
     }
+    if (CXXRecordDecl *recordDecl = expr->getRecordDecl()) {
+      m_visitedTypes.insert(recordDecl);
+    }
     return true;
   }
 };
@@ -451,6 +483,8 @@ void PrintTranslationUnitWithTranslatedUniformParams(
 
 static HRESULT DoRewriteUnused( TranslationUnitDecl *tu,
                                 LPCSTR pEntryPoint,
+                                bool bRemoveGlobals,
+                                bool bRemoveFunctions,
                                 raw_ostream &w) {
   ASTContext& C = tu->getASTContext();
 
@@ -458,11 +492,22 @@ static HRESULT DoRewriteUnused( TranslationUnitDecl *tu,
   SmallPtrSet<VarDecl*, 128> unusedGlobals;
   DenseMap<RecordDecl*, unsigned> anonymousRecordRefCounts;
   SmallPtrSet<FunctionDecl*, 128> unusedFunctions;
+  SmallPtrSet<TypeDecl*, 32> unusedTypes;
+  SmallVector<VarDecl*, 32> nonStaticGlobals;
   for (Decl *tuDecl : tu->decls()) {
     if (tuDecl->isImplicit()) continue;
 
     VarDecl* varDecl = dyn_cast_or_null<VarDecl>(tuDecl);
     if (varDecl != nullptr) {
+      if (!bRemoveGlobals) {
+        // Only remove static global when not remove global.
+        if (!(varDecl->getStorageClass() == SC_Static ||
+            varDecl->isInAnonymousNamespace())) {
+          nonStaticGlobals.emplace_back(varDecl);
+          continue;
+        }
+      }
+
       unusedGlobals.insert(varDecl);
       if (const RecordType *recordType = varDecl->getType()->getAs<RecordType>()) {
         RecordDecl *recordDecl = recordType->getDecl();
@@ -475,10 +520,16 @@ static HRESULT DoRewriteUnused( TranslationUnitDecl *tu,
 
     FunctionDecl* fnDecl = dyn_cast_or_null<FunctionDecl>(tuDecl);
     if (fnDecl != nullptr) {
-      if (fnDecl->doesThisDeclarationHaveABody()) {
+      FunctionDecl *fnDeclWithbody = getFunctionWithBody(fnDecl);
+      // Add fnDecl without body which has a define somewhere.
+      if (fnDecl->doesThisDeclarationHaveABody() || fnDeclWithbody) {
         unusedFunctions.insert(fnDecl);
       }
     }
+
+    if (TagDecl *tagDecl = dyn_cast<TagDecl>(tuDecl)) {
+      unusedTypes.insert(tagDecl);
+    }
   }
 
   w << "//found " << unusedGlobals.size() << " globals as candidates for removal\n";
@@ -501,16 +552,18 @@ static HRESULT DoRewriteUnused( TranslationUnitDecl *tu,
   // Traverse reachable functions and variables.
   SmallPtrSet<FunctionDecl*, 128> visitedFunctions;
   SmallVector<FunctionDecl*, 32> pendingFunctions;
-  VarReferenceVisitor visitor(unusedGlobals, visitedFunctions, pendingFunctions);
+  SmallPtrSet<TypeDecl*, 32> visitedTypes;
+  VarReferenceVisitor visitor(unusedGlobals, visitedFunctions, pendingFunctions,
+                              visitedTypes);
   pendingFunctions.push_back(entryFnDecl);
-  while (!pendingFunctions.empty() && !unusedGlobals.empty()) {
+  while (!pendingFunctions.empty()) {
     FunctionDecl* pendingDecl = pendingFunctions.pop_back_val();
     visitedFunctions.insert(pendingDecl);
     visitor.TraverseDecl(pendingDecl);
   }
 
   // Don't bother doing work if there are no globals to remove.
-  if (unusedGlobals.empty()) {
+  if (unusedGlobals.empty() && unusedFunctions.empty() && unusedTypes.empty()) {
     return S_FALSE;
   }
 
@@ -522,6 +575,18 @@ static HRESULT DoRewriteUnused( TranslationUnitDecl *tu,
   }
   w << "//found " << unusedFunctions.size() << " functions to remove\n";
 
+  for (TypeDecl *typeDecl : visitedTypes) {
+    unusedTypes.erase(typeDecl);
+  }
+
+  for (VarDecl *varDecl : nonStaticGlobals) {
+    if (TagDecl *tagDecl = varDecl->getType()->getAsTagDecl()) {
+      unusedTypes.erase(tagDecl);
+    }
+  }
+  w << "//found " << unusedTypes.size() << " types to remove\n";
+
+
   // Remove all unused variables and functions.
   for (VarDecl *unusedGlobal : unusedGlobals) {
     if (const RecordType *recordTy = unusedGlobal->getType()->getAs<RecordType>()) {
@@ -539,14 +604,26 @@ static HRESULT DoRewriteUnused( TranslationUnitDecl *tu,
         }
       }
     }
-
+    if (HLSLBufferDecl *CBV = dyn_cast<HLSLBufferDecl>(unusedGlobal->getLexicalDeclContext())) {
+      if (CBV->isConstantBufferView()) {
+        // For constant buffer view, we create a variable for the constant.
+        // The variable use tu as the DeclContext to access as global variable, CBV as LexicalDeclContext so it is still part of CBV.
+        // setLexicalDeclContext to tu to avoid assert when remove.
+        unusedGlobal->setLexicalDeclContext(tu);
+      }
+    }
     tu->removeDecl(unusedGlobal);
   }
 
   for (FunctionDecl *unusedFn : unusedFunctions) {
+    // remove name of function to workaround assert when update lookup table.
+    unusedFn->setDeclName(DeclarationName());
     tu->removeDecl(unusedFn);
   }
 
+  for (TypeDecl *unusedTy : unusedTypes) {
+    tu->removeDecl(unusedTy);
+  }
   // Flush and return results.
   w.flush();
   return S_OK;
@@ -558,6 +635,8 @@ HRESULT DoRewriteUnused(_In_ DxcLangExtensionsHelper *pHelper,
                      _In_ ASTUnit::RemappedFile *pRemap,
                      _In_ LPCSTR pEntryPoint,
                      _In_ LPCSTR pDefines,
+                     bool bRemoveGlobals,
+                     bool bRemoveFunctions,
                      std::string &warnings,
                      std::string &result,
                      _In_opt_ dxcutil::DxcArgsFileSystem *msfPtr) {
@@ -585,7 +664,8 @@ HRESULT DoRewriteUnused(_In_ DxcLangExtensionsHelper *pHelper,
   if (compiler.getDiagnosticClient().getNumErrors() > 0)
     return E_FAIL;
 
-  HRESULT hr = DoRewriteUnused(tu, pEntryPoint, w);
+  HRESULT hr =
+      DoRewriteUnused(tu, pEntryPoint, bRemoveGlobals, bRemoveFunctions, w);
   if (FAILED(hr))
     return hr;
 
@@ -692,8 +772,10 @@ HRESULT DoSimpleReWrite(_In_ DxcLangExtensionsHelper *pHelper,
   if (opts.EntryPoint.empty())
     opts.EntryPoint = "main";
 
-  if (opts.RWOpt.RemoveUnusedGlobals) {
-    HRESULT hr = DoRewriteUnused(tu, opts.EntryPoint.data(), w);
+  if (opts.RWOpt.RemoveUnusedGlobals || opts.RWOpt.RemoveUnusedFunctions) {
+    HRESULT hr = DoRewriteUnused(tu, opts.EntryPoint.data(),
+                                 opts.RWOpt.RemoveUnusedGlobals,
+                                     opts.RWOpt.RemoveUnusedFunctions, w);
     if (FAILED(hr))
       return hr;
   } else {
@@ -785,7 +867,7 @@ public:
       LPCWSTR pOutputName = nullptr;  // TODO: Fill this in
       HRESULT status = DoRewriteUnused(
           &m_langExtensionsHelper, fakeName, pRemap.get(), utf8EntryPoint,
-          defineCount > 0 ? definesStr.c_str() : nullptr, errors, rewrite, nullptr);
+          defineCount > 0 ? definesStr.c_str() : nullptr, true/*removeGlobals*/, false/*removeFunctions*/,errors, rewrite, nullptr);
       return DxcResult::Create(status, DXC_OUT_HLSL, {
           DxcOutputObject::StringOutput(DXC_OUT_HLSL, CP_UTF8,  // TODO: Support DefaultTextCodePage
             rewrite.c_str(), pOutputName),