Sfoglia il codice sorgente

Fix allocator issues for dxc blob allocations (#2130)

DxcCreateBlobOnHeapCopy allocated using the COM allocator and freed using the default thread allocator, leading to crashes.

But more generally, some DxcCreateBlob functions would take ownership of the passed-in buffer without also having the caller explicitly specify what IMalloc to use for deallocation, which is an error-prone pattern, so I reworked these methods a bit.

This also lead me to find some memory leaks and general memory mismanagement in the the rewriter.
Tristan Labelle 6 anni fa
parent
commit
d56f6eb93f

+ 11 - 9
include/dxc/Support/FileIOHelper.h

@@ -77,6 +77,8 @@ public:
 
   operator T *() const throw() { return m_pData; }
 
+  IMalloc* GetMallocNoRef() const throw() { return m_pMalloc.p; }
+
   bool Allocate(_In_ SIZE_T ElementCount) throw() {
     ATLASSERT(m_pData == NULL);
     SIZE_T nBytes = ElementCount * sizeof(T);
@@ -142,10 +144,12 @@ HRESULT DxcCreateBlobFromBlob(_In_ IDxcBlob *pBlob, UINT32 offset,
                               UINT32 length,
                               _COM_Outptr_ IDxcBlob **ppResult) throw();
 
+// Creates a blob wrapping a buffer to be freed with the provided IMalloc
 HRESULT
-DxcCreateBlobOnHeap(_In_bytecount_(size) LPCVOID pData, UINT32 size,
-                    _COM_Outptr_ IDxcBlob **ppResult) throw();
+DxcCreateBlobOnMalloc(_In_bytecount_(size) LPCVOID pData, _In_ IMalloc* pIMalloc,
+                      UINT32 size, _COM_Outptr_ IDxcBlob **ppResult) throw();
 
+// Creates a blob with a copy of the provided data
 HRESULT
 DxcCreateBlobOnHeapCopy(_In_bytecount_(size) LPCVOID pData, UINT32 size,
                         _COM_Outptr_ IDxcBlob **ppResult) throw();
@@ -159,6 +163,7 @@ DxcCreateBlobWithEncodingSet(
     _In_ IMalloc *pMalloc, _In_ IDxcBlob *pBlob, UINT32 codePage,
     _COM_Outptr_ IDxcBlobEncoding **ppBlobEncoding) throw();
 
+// Creates a blob around encoded text without ownership transfer
 HRESULT DxcCreateBlobWithEncodingFromPinned(
     _In_bytecount_(size) LPCVOID pText, UINT32 size, UINT32 codePage,
     _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
@@ -168,22 +173,19 @@ DxcCreateBlobWithEncodingFromStream(
     IStream *pStream, bool newInstanceAlways, UINT32 codePage,
     _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
 
-HRESULT
-DxcCreateBlobWithEncodingOnHeap(_In_bytecount_(size) LPCVOID pText, UINT32 size,
-                                UINT32 codePage,
-                                _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
-
-// Should rename this 'OnHeap' to be 'OnMalloc', change callers to pass arg. Using TLS.
+// Creates a blob with a copy of the encoded text
 HRESULT
 DxcCreateBlobWithEncodingOnHeapCopy(
     _In_bytecount_(size) LPCVOID pText, UINT32 size, UINT32 codePage,
     _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
 
+// Creates a blob wrapping encoded text to be freed with the provided IMalloc
 HRESULT
 DxcCreateBlobWithEncodingOnMalloc(
-  _In_bytecount_(size) LPCVOID pText, IMalloc *pIMalloc, UINT32 size, UINT32 codePage,
+  _In_bytecount_(size) LPCVOID pText, _In_ IMalloc *pIMalloc, UINT32 size, UINT32 codePage,
   _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
 
+// Creates a blob with a copy of encoded text, allocated using the provided IMalloc
 HRESULT
 DxcCreateBlobWithEncodingOnMallocCopy(
   _In_ IMalloc *pIMalloc, _In_bytecount_(size) LPCVOID pText, UINT32 size, UINT32 codePage,

+ 1 - 2
include/dxc/Support/dxcapi.impl.h

@@ -74,7 +74,6 @@ public:
     *pResult = nullptr;
     CComPtr<IDxcBlobEncoding> resultBlob;
     CComPtr<IDxcBlobEncoding> errorBlob;
-    CComPtr<DxcOperationResult> result;
 
     HRESULT hr = S_OK;
 
@@ -87,7 +86,7 @@ public:
     }
 
     if (pResultStr != nullptr) {
-      hr = hlsl::DxcCreateBlobWithEncodingOnHeap(
+      hr = hlsl::DxcCreateBlobWithEncodingOnHeapCopy(
         pResultStr, strlen(pResultStr), CP_UTF8, &resultBlob);
       if (FAILED(hr)) {
         return hr;

+ 10 - 31
lib/DxcSupport/FileIOHelper.cpp

@@ -206,14 +206,6 @@ public:
     }
   }
 
-  static HRESULT
-  CreateFromHeap(LPCVOID buffer, SIZE_T bufferSize, bool encodingKnown,
-                 UINT32 codePage,
-                 _COM_Outptr_ InternalDxcBlobEncoding **ppEncoding) {
-    return CreateFromMalloc(buffer, DxcGetThreadMallocNoRef(), bufferSize,
-                            encodingKnown, codePage, ppEncoding);
-  }
-
   static HRESULT
   CreateFromBlob(_In_ IDxcBlob *pBlob, _In_ IMalloc *pMalloc, bool encodingKnown, UINT32 codePage,
                  _COM_Outptr_ InternalDxcBlobEncoding **pEncoding) {
@@ -346,14 +338,14 @@ HRESULT DxcCreateBlobFromBlob(
 
 _Use_decl_annotations_
 HRESULT
-DxcCreateBlobOnHeap(LPCVOID pData, UINT32 size, IDxcBlob **ppResult) throw() {
+DxcCreateBlobOnMalloc(LPCVOID pData, IMalloc *pIMalloc, UINT32 size, IDxcBlob **ppResult) throw() {
   if (pData == nullptr || ppResult == nullptr) {
     return E_POINTER;
   }
 
   *ppResult = nullptr;
   CComPtr<InternalDxcBlobEncoding> blob;
-  IFR(InternalDxcBlobEncoding::CreateFromHeap(pData, size, false, 0, &blob));
+  IFR(InternalDxcBlobEncoding::CreateFromMalloc(pData, pIMalloc, size, false, 0, &blob));
   *ppResult = blob.Detach();
   return S_OK;
 }
@@ -368,14 +360,15 @@ DxcCreateBlobOnHeapCopy(_In_bytecount_(size) LPCVOID pData, UINT32 size,
 
   *ppResult = nullptr;
 
-  CComHeapPtr<char> heapCopy;
-  if (!heapCopy.AllocateBytes(size)) {
+  CDxcMallocHeapPtr<char> heapCopy(DxcGetThreadMallocNoRef());
+  if (!heapCopy.Allocate(size)) {
     return E_OUTOFMEMORY;
   }
   memcpy(heapCopy.m_pData, pData, size);
 
   CComPtr<InternalDxcBlobEncoding> blob;
-  IFR(InternalDxcBlobEncoding::CreateFromHeap(heapCopy.m_pData, size, false, 0, &blob));
+  IFR(InternalDxcBlobEncoding::CreateFromMalloc(heapCopy.m_pData,
+    heapCopy.GetMallocNoRef(), size, false, 0, &blob));
   heapCopy.Detach();
   *ppResult = blob.Detach();
   return S_OK;
@@ -453,8 +446,8 @@ HRESULT DxcCreateBlobWithEncodingFromPinned(LPCVOID pText, UINT32 size,
   *pBlobEncoding = nullptr;
 
   InternalDxcBlobEncoding *internalEncoding;
-  HRESULT hr = InternalDxcBlobEncoding::CreateFromHeap(
-      pText, size, true, codePage, &internalEncoding);
+  HRESULT hr = InternalDxcBlobEncoding::CreateFromMalloc(
+      pText, DxcGetThreadMallocNoRef(), size, true, codePage, &internalEncoding);
   if (SUCCEEDED(hr)) {
     internalEncoding->ClearFreeFlag();
     *pBlobEncoding = internalEncoding;
@@ -492,21 +485,6 @@ DxcCreateBlobWithEncodingFromStream(IStream *pStream, bool newInstanceAlways,
   return E_NOTIMPL;
 }
 
-_Use_decl_annotations_
-HRESULT
-DxcCreateBlobWithEncodingOnHeap(LPCVOID pText, UINT32 size, UINT32 codePage,
-                                IDxcBlobEncoding **pBlobEncoding) throw() {
-  *pBlobEncoding = nullptr;
-
-  InternalDxcBlobEncoding *internalEncoding;
-  HRESULT hr = InternalDxcBlobEncoding::CreateFromHeap(
-      pText, size, true, codePage, &internalEncoding);
-  if (SUCCEEDED(hr)) {
-    *pBlobEncoding = internalEncoding;
-  }
-  return hr;
-}
-
 _Use_decl_annotations_
 HRESULT
 DxcCreateBlobWithEncodingOnHeapCopy(LPCVOID pText, UINT32 size, UINT32 codePage,
@@ -520,7 +498,8 @@ DxcCreateBlobWithEncodingOnHeapCopy(LPCVOID pText, UINT32 size, UINT32 codePage,
   memcpy(heapCopy.m_pData, pText, size);
 
   InternalDxcBlobEncoding* internalEncoding;
-  HRESULT hr = InternalDxcBlobEncoding::CreateFromHeap(heapCopy.m_pData, size, true, codePage, &internalEncoding);
+  HRESULT hr = InternalDxcBlobEncoding::CreateFromMalloc(heapCopy.m_pData,
+    heapCopy.GetMallocNoRef(), size, true, codePage, &internalEncoding);
   if (SUCCEEDED(hr)) {
     *pBlobEncoding = internalEncoding;
     heapCopy.Detach();

+ 1 - 1
lib/DxilRootSignature/DxilRootSignatureSerializer.cpp

@@ -271,7 +271,7 @@ void SerializeRootSignatureTemplate(_In_ const T_ROOT_SIGNATURE_DESC* pRootSigna
   DXASSERT_NOMSG((cb & 0x3) == 0);
   IFTBOOL(bytes.Allocate(cb), E_OUTOFMEMORY);
   IFT(Serializer.Compact(bytes.m_pData, cb));
-  IFT(DxcCreateBlobOnHeap(bytes.m_pData, cb, ppBlob));
+  IFT(DxcCreateBlobOnMalloc(bytes.m_pData, bytes.GetMallocNoRef(), cb, ppBlob));
   bytes.Detach(); // Ownership transfered to ppBlob.
 }
 

+ 26 - 32
tools/clang/tools/libclang/dxcrewriteunused.cpp

@@ -363,13 +363,10 @@ HRESULT DoRewriteUnused(_In_ DxcLangExtensionsHelper *pHelper,
                      _In_ ASTUnit::RemappedFile *pRemap,
                      _In_ LPCSTR pEntryPoint,
                      _In_ LPCSTR pDefines,
-                     _Outptr_result_z_ LPSTR *pWarnings,
-                     _Outptr_result_z_ LPSTR *pResult) {
-  if (pWarnings != nullptr) *pWarnings = nullptr;
-  if (pResult != nullptr) *pResult = nullptr;
+                     std::string &warnings,
+                     std::string &result) {
 
-  std::string s, warnings;
-  raw_string_ostream o(s);
+  raw_string_ostream o(result);
   raw_string_ostream w(warnings);
 
   // Setup a compiler instance.
@@ -493,8 +490,8 @@ HRESULT DoRewriteUnused(_In_ DxcLangExtensionsHelper *pHelper,
   }
 
   // Flush and return results.
-  raw_string_ostream_to_CoString(o, pResult);
-  raw_string_ostream_to_CoString(w, pWarnings);
+  o.flush();
+  w.flush();
 
   if (compiler.getDiagnosticClient().getNumErrors() > 0)
     return E_FAIL;
@@ -547,18 +544,15 @@ HRESULT DoSimpleReWrite(_In_ DxcLangExtensionsHelper *pHelper,
                _In_ hlsl::options::DxcOpts &opts,
                _In_ LPCSTR pDefines,
                _In_ UINT32 rewriteOption,
-               _Outptr_result_z_ LPSTR *pWarnings,
-               _Outptr_result_z_ LPSTR *pResult) {
-  if (pWarnings != nullptr) *pWarnings = nullptr;
-  if (pResult != nullptr) *pResult = nullptr;
+               std::string &warnings,
+               std::string &result) {
 
   bool bSkipFunctionBody = rewriteOption & RewriterOptionMask::SkipFunctionBody;
   bool bSkipStatic = rewriteOption & RewriterOptionMask::SkipStatic;
   bool bGlobalExternByDefault = rewriteOption & RewriterOptionMask::GlobalExternByDefault;
   bool bKeepUserMacro = rewriteOption & RewriterOptionMask::KeepUserMacro;
 
-  std::string s, warnings;
-  raw_string_ostream o(s);
+  raw_string_ostream o(result);
   raw_string_ostream w(warnings);
 
   // Setup a compiler instance.
@@ -594,8 +588,8 @@ HRESULT DoSimpleReWrite(_In_ DxcLangExtensionsHelper *pHelper,
     WriteUserMacroDefines(compiler, o);
 
   // Flush and return results.
-  raw_string_ostream_to_CoString(o, pResult);
-  raw_string_ostream_to_CoString(w, pWarnings);
+  o.flush();
+  w.flush();
 
   if (compiler.getDiagnosticClient().getNumErrors() > 0)
     return E_FAIL;
@@ -648,12 +642,12 @@ public:
       CW2A utf8EntryPoint(pEntryPoint, CP_UTF8);
       std::string definesStr = DefinesToString(pDefines, defineCount);
 
-      LPSTR errors = nullptr;
-      LPSTR rewrite = nullptr;
+      std::string errors;
+      std::string rewrite;
       HRESULT status = DoRewriteUnused(
           &m_langExtensionsHelper, fakeName, pRemap.get(), utf8EntryPoint,
-          defineCount > 0 ? definesStr.c_str() : nullptr, &errors, &rewrite);
-      return DxcOperationResult::CreateFromUtf8Strings(errors, rewrite, status,
+          defineCount > 0 ? definesStr.c_str() : nullptr, errors, rewrite);
+      return DxcOperationResult::CreateFromUtf8Strings(errors.c_str(), rewrite.c_str(), status,
                                                        ppResult);
     }
     CATCH_CPP_RETURN_HRESULT();
@@ -692,14 +686,14 @@ public:
       hlsl::options::DxcOpts opts;
       opts.HLSLVersion = 2015;
 
-      LPSTR errors = nullptr;
-      LPSTR rewrite = nullptr;
+      std::string errors;
+      std::string rewrite;
       HRESULT status =
           DoSimpleReWrite(&m_langExtensionsHelper, fakeName, pRemap.get(), opts,
                           defineCount > 0 ? definesStr.c_str() : nullptr,
-                          RewriterOptionMask::Default, &errors, &rewrite);
+                          RewriterOptionMask::Default, errors, rewrite);
 
-      return DxcOperationResult::CreateFromUtf8Strings(errors, rewrite, status,
+      return DxcOperationResult::CreateFromUtf8Strings(errors.c_str(), rewrite.c_str(), status,
                                                        ppResult);
     }
     CATCH_CPP_RETURN_HRESULT();
@@ -743,14 +737,14 @@ public:
       hlsl::options::DxcOpts opts;
       opts.HLSLVersion = 2015;
 
-      LPSTR errors = nullptr;
-      LPSTR rewrite = nullptr;
+      std::string errors;
+      std::string rewrite;
       HRESULT status =
           DoSimpleReWrite(&m_langExtensionsHelper, fName, pRemap.get(), opts,
                           defineCount > 0 ? definesStr.c_str() : nullptr,
-                          rewriteOption, &errors, &rewrite);
+                          rewriteOption, errors, rewrite);
 
-      return DxcOperationResult::CreateFromUtf8Strings(errors, rewrite, status,
+      return DxcOperationResult::CreateFromUtf8Strings(errors.c_str(), rewrite.c_str(), status,
                                                        ppResult);
     }
     CATCH_CPP_RETURN_HRESULT();
@@ -803,14 +797,14 @@ public:
       hlsl::options::DxcOpts opts;
       IFR(ReadOptsAndValidate(pArguments, argCount, opts, ppResult));
 
-      LPSTR errors = nullptr;
-      LPSTR rewrite = nullptr;
+      std::string errors;
+      std::string rewrite;
       HRESULT status =
           DoSimpleReWrite(&m_langExtensionsHelper, fName, pRemap.get(), opts,
                           defineCount > 0 ? definesStr.c_str() : nullptr,
-                          Default, &errors, &rewrite);
+                          Default, errors, rewrite);
 
-      return DxcOperationResult::CreateFromUtf8Strings(errors, rewrite, status,
+      return DxcOperationResult::CreateFromUtf8Strings(errors.c_str(), rewrite.c_str(), status,
                                                        ppResult);
     }
     CATCH_CPP_RETURN_HRESULT();