2
0
Эх сурвалжийг харах

Fix dxc_batch crash. (#1203)

* Fix dxc_batch crash.

* Use Twine as temp and replace StringSet with unordered_set<string>
Xiang Li 7 жил өмнө
parent
commit
d5ba8b7081

+ 7 - 6
tools/clang/tools/dxlib-sample/lib_share_preprocessor.cpp

@@ -20,6 +20,7 @@
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Support/Path.h"
+#include <unordered_set>
 
 using namespace libshare;
 
@@ -65,7 +66,8 @@ public:
                                                                // if not found.
   ) {
     CW2A pUtf8Filename(pFilename);
-    if (m_loadedFileNames.count(pUtf8Filename.m_psz)) {
+    if (m_loadedFileNames.find(pUtf8Filename.m_psz) !=
+        m_loadedFileNames.end()) {
       // Already include this file.
       // Just return empty content.
       static const char kEmptyStr[] = " ";
@@ -95,26 +97,25 @@ public:
         IFT(m_pIncludeHandler->LoadSource(pFilename, ppIncludeSource));
       }
     }
-
     CComPtr<IDxcBlobEncoding> utf8Source;
     IFT(hlsl::DxcGetBlobAsUtf8(*ppIncludeSource, &utf8Source));
 
-    StringRef Data((LPSTR)utf8Source->GetBufferPointer());
-    Twine regionData = file_region + Data + file_region;
-    std::string strRegionData = regionData.str();
+    StringRef Data((LPSTR)utf8Source->GetBufferPointer(), utf8Source->GetBufferSize());
+    std::string strRegionData = (Twine(file_region) + Data + file_region).str();
 
     CComPtr<IDxcBlobEncoding> pEncodingIncludeSource;
     IFT(DxcCreateBlobWithEncodingOnMallocCopy(
         GetGlobalHeapMalloc(), strRegionData.c_str(), strRegionData.size(),
         CP_UTF8, &pEncodingIncludeSource));
     *ppIncludeSource = pEncodingIncludeSource.Detach();
+    m_loadedFileNames.insert(pUtf8Filename.m_psz);
     return S_OK;
   }
 
 private:
   DXC_MICROCOM_REF_FIELD(m_dwRef)
   IDxcIncludeHandler *m_pIncludeHandler;
-  StringSet<> m_loadedFileNames;
+  std::unordered_set<std::string> m_loadedFileNames;
   std::vector<std::string> &m_includePathList;
 };
 

+ 3 - 2
tools/clang/unittests/dxc_batch/dxc_batch.cpp

@@ -789,7 +789,7 @@ int DxcBatchContext::BatchCompile(bool bMultiThread, bool bLibLink) {
   llvm::StringRef source((char *)pSource->GetBufferPointer(),
                          pSource->GetBufferSize());
   llvm::SmallVector<llvm::StringRef, 4> commands;
-  source.split(commands, "\r\n");
+  source.split(commands, "\n");
 
   if (bMultiThread) {
     unsigned int threadNum = std::min<unsigned>(
@@ -800,7 +800,8 @@ int DxcBatchContext::BatchCompile(bool bMultiThread, bool bLibLink) {
       threads[i] = std::thread(empty_fn);
 
     for (unsigned i = 0; i < commands.size(); i++) {
-      llvm::StringRef command = commands[i];
+      // trim to remove /r if exist.
+      llvm::StringRef command = commands[i].trim();
       if (command.empty())
         continue;
       if (command.startswith("//"))