Browse Source

Add SkipStatic to rewriter. (#394)

Xiang Li 8 years ago
parent
commit
fe1e5ccc61

+ 1 - 1
include/dxc/dxctools.h

@@ -17,7 +17,7 @@
 enum RewirterOptionMask {
   Default = 0,
   SkipFunctionBody = 1,
-
+  SkipStatic = 2,
 };
 
 struct __declspec(uuid("c012115b-8893-4eb9-9c5a-111456ea1c45"))

+ 0 - 2
tools/clang/tools/dxcompiler/dxclinker.cpp

@@ -126,8 +126,6 @@ DxcLinker::RegisterLibrary(_In_opt_ LPCWSTR pLibName, // Name of the library.
 
     raw_stream_ostream DiagStream(pDiagStream);
 
-    llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
-
     IFR(ValidateLoadModuleFromContainer(
         pBlob->GetBufferPointer(), pBlob->GetBufferSize(), pModule,
         pDebugModule, m_Ctx, m_Ctx, DiagStream));

+ 33 - 5
tools/clang/tools/libclang/dxcrewriteunused.cpp

@@ -351,17 +351,42 @@ HRESULT DoRewriteUnused(_In_ DxcLangExtensionsHelper *pHelper,
   return S_OK;
 }
 
+static void RemoveStaticDecls(DeclContext &Ctx) {
+  for (auto it = Ctx.decls_begin(); it != Ctx.decls_end(); ) {
+    auto cur = it++;
+    if (VarDecl *VD = dyn_cast<VarDecl>(*cur)) {
+      if (VD->getStorageClass() == SC_Static || VD->isInAnonymousNamespace()) {
+        Ctx.removeDecl(VD);
+      }
+    }
+    if (FunctionDecl *FD = dyn_cast<FunctionDecl>(*cur)) {
+      if (isa<CXXMethodDecl>(FD))
+        continue;
+      if (FD->getStorageClass() == SC_Static || FD->isInAnonymousNamespace()) {
+        Ctx.removeDecl(FD);
+      }
+    }
+
+    if (DeclContext *DC = dyn_cast<DeclContext>(*cur)) {
+      RemoveStaticDecls(*DC);
+    }
+  }
+}
+
 static
 HRESULT DoSimpleReWrite(_In_ DxcLangExtensionsHelper *pHelper,
                _In_ LPCSTR pFileName,
                _In_ ASTUnit::RemappedFile *pRemap,
                _In_ LPCSTR pDefines,
-               _In_ bool bSkipFunctionBody,
+               _In_ UINT32 rewriteOption,
                _Outptr_result_z_ LPSTR *pWarnings,
                _Outptr_result_z_ LPSTR *pResult) {
   if (pWarnings != nullptr) *pWarnings = nullptr;
   if (pResult != nullptr) *pResult = nullptr;
 
+  bool bSkipFunctionBody = rewriteOption & RewirterOptionMask::SkipFunctionBody;
+  bool bSkipStatic = rewriteOption & RewirterOptionMask::SkipStatic;
+
   std::string s, warnings;
   raw_string_ostream o(s);
   raw_string_ostream w(warnings);
@@ -379,6 +404,11 @@ HRESULT DoSimpleReWrite(_In_ DxcLangExtensionsHelper *pHelper,
   ASTContext& C = compiler.getASTContext();
   TranslationUnitDecl *tu = C.getTranslationUnitDecl();
 
+  if (bSkipStatic && bSkipFunctionBody) {
+    // Remove static functions and globals.
+    RemoveStaticDecls(*tu);
+  }
+
   o << "// Rewrite unchanged result:\n";
   PrintingPolicy p = PrintingPolicy(C.getPrintingPolicy());
   p.Indentation = 1;
@@ -485,7 +515,7 @@ public:
       HRESULT status =
           DoSimpleReWrite(&m_langExtensionsHelper, fakeName, pRemap.get(),
                           defineCount > 0 ? definesStr.c_str() : nullptr,
-                          /*bSkipFunctionBody*/ false, &errors, &rewrite);
+                          RewirterOptionMask::Default, &errors, &rewrite);
 
       return DxcOperationResult::CreateFromUtf8Strings(errors, rewrite, status,
                                                        ppResult);
@@ -530,12 +560,10 @@ public:
 
       LPSTR errors = nullptr;
       LPSTR rewrite = nullptr;
-      bool bSkipFunctionBody =
-          rewriteOption & RewirterOptionMask::SkipFunctionBody;
       HRESULT status =
           DoSimpleReWrite(&m_langExtensionsHelper, fName, pRemap.get(),
                           defineCount > 0 ? definesStr.c_str() : nullptr,
-                          bSkipFunctionBody, &errors, &rewrite);
+                          rewriteOption, &errors, &rewrite);
 
       return DxcOperationResult::CreateFromUtf8Strings(errors, rewrite, status,
                                                        ppResult);

+ 30 - 0
tools/clang/unittests/HLSL/RewriterTest.cpp

@@ -75,6 +75,7 @@ public:
   TEST_METHOD(RunSemanticDefines);
   TEST_METHOD(RunNoFunctionBody);
   TEST_METHOD(RunNoFunctionBodyInclude);
+  TEST_METHOD(RunNoStatic);
 
   dxc::DxcDllSupport m_dllSupport;
   CComPtr<IDxcIncludeHandler> m_pIncludeHandler;
@@ -491,4 +492,33 @@ TEST_F(RewriterTest, RunNoFunctionBody) {
   VERIFY_IS_TRUE(strcmp(BlobToUtf8(result).c_str(),
                         "// Rewrite unchanged result:\nfloat pick_one(float2 "
                         "f2);\nvoid main();\n") == 0);
+}
+
+TEST_F(RewriterTest, RunNoStatic) {
+  CComPtr<IDxcRewriter> pRewriter;
+  VERIFY_SUCCEEDED(CreateRewriter(&pRewriter));
+  CComPtr<IDxcOperationResult> pRewriteResult;
+
+  // Get the source text from a file
+  FileWithBlob source(
+      m_dllSupport,
+      GetPathToHlslDataFile(L"rewriter\\attributes_noerr.hlsl")
+          .c_str());
+
+  const int myDefinesCount = 3;
+  DxcDefine myDefines[myDefinesCount] = {
+      {L"myDefine", L"2"}, {L"myDefine3", L"1994"}, {L"myDefine4", nullptr}};
+
+  // Run rewrite no function body on the source code
+  VERIFY_SUCCEEDED(pRewriter->RewriteUnchangedWithInclude(
+      source.BlobEncoding, L"attributes_noerr.hlsl", myDefines, myDefinesCount,
+      /*pIncludeHandler*/ nullptr,
+      RewirterOptionMask::SkipFunctionBody | RewirterOptionMask::SkipStatic,
+      &pRewriteResult));
+
+  CComPtr<IDxcBlob> result;
+  VERIFY_SUCCEEDED(pRewriteResult->GetResult(&result));
+  std::string strResult = BlobToUtf8(result);
+  // No static.
+  VERIFY_IS_TRUE(strResult.find("static") == std::string::npos);
 }