소스 검색

Add support for custom allocators (#390)

Supporting a custom allocator for dxcompiler.
Adds recovery for exceptions and out-of-memory handling.
Add custom allocator support to linker.
Fix for release-only test failure.
Removes assertion about presence of command-line option registration
Marcelo Lopez Ruiz 8 년 전
부모
커밋
d5bb3089cf
100개의 변경된 파일1397개의 추가작업 그리고 351개의 파일을 삭제
  1. 6 0
      include/dxc/Support/DxcLangExtensionsHelper.h
  2. 113 4
      include/dxc/Support/FileIOHelper.h
  3. 48 1
      include/dxc/Support/Global.h
  4. 2 0
      include/dxc/Support/HLSLOptions.h
  5. 1 0
      include/dxc/Support/WinIncludes.h
  6. 12 8
      include/dxc/Support/dxcapi.impl.h
  7. 33 2
      include/dxc/Support/dxcapi.use.h
  8. 2 3
      include/dxc/Support/dxcfilesystem.h
  9. 36 6
      include/dxc/Support/microcom.h
  10. 15 0
      include/dxc/dxcapi.h
  11. 18 5
      include/llvm/ADT/DenseMap.h
  12. 1 1
      include/llvm/ADT/Statistic.h
  13. 1 0
      include/llvm/Analysis/CallGraph.h
  14. 2 0
      include/llvm/Bitcode/BitstreamWriter.h
  15. 2 2
      include/llvm/IR/LegacyPassManagers.h
  16. 5 2
      include/llvm/IR/Metadata.h
  17. 2 3
      include/llvm/IR/User.h
  18. 2 2
      include/llvm/Support/CommandLine.h
  19. 4 0
      include/llvm/Support/Mutex.h
  20. 9 2
      include/llvm/Support/raw_ostream.h
  21. 4 0
      lib/Analysis/DependenceAnalysis.cpp
  22. 28 12
      lib/Analysis/IPA/CallGraph.cpp
  23. 6 0
      lib/Analysis/IPA/CallGraphSCCPass.cpp
  24. 10 1
      lib/Analysis/LoopAccessAnalysis.cpp
  25. 3 0
      lib/Analysis/LoopInfo.cpp
  26. 2 0
      lib/Analysis/LoopPass.cpp
  27. 2 0
      lib/Analysis/RegionPass.cpp
  28. 5 0
      lib/Analysis/ScalarEvolution.cpp
  29. 4 0
      lib/Analysis/ScopedNoAliasAA.cpp
  30. 4 0
      lib/Analysis/TargetLibraryInfo.cpp
  31. 4 0
      lib/Analysis/TypeBasedAliasAnalysis.cpp
  32. 9 1
      lib/Analysis/ValueTracking.cpp
  33. 3 1
      lib/Analysis/regioninfo.cpp
  34. 54 54
      lib/Bitcode/Writer/BitcodeWriter.cpp
  35. 1 0
      lib/DxcSupport/CMakeLists.txt
  36. 205 83
      lib/DxcSupport/FileIOHelper.cpp
  37. 15 2
      lib/DxcSupport/HLSLOptions.cpp
  38. 94 0
      lib/DxcSupport/dxcmem.cpp
  39. 15 14
      lib/HLSL/DxcOptimizer.cpp
  40. 2 6
      lib/HLSL/DxilContainerAssembler.cpp
  41. 14 11
      lib/HLSL/DxilContainerReflection.cpp
  42. 3 1
      lib/HLSL/DxilGenerationPass.cpp
  43. 5 3
      lib/HLSL/DxilRootSignature.cpp
  44. 2 6
      lib/HLSL/DxilValidation.cpp
  45. 17 0
      lib/IR/DiagnosticInfo.cpp
  46. 2 0
      lib/IR/Dominators.cpp
  47. 14 0
      lib/IR/Function.cpp
  48. 2 2
      lib/IR/LLVMContextImpl.cpp
  49. 36 6
      lib/IR/LegacyPassManager.cpp
  50. 14 6
      lib/IR/Metadata.cpp
  51. 5 2
      lib/IR/Module.cpp
  52. 9 0
      lib/IR/User.cpp
  53. 3 3
      lib/IR/Value.cpp
  54. 4 0
      lib/IR/Verifier.cpp
  55. 17 4
      lib/Support/CommandLine.cpp
  56. 5 4
      lib/Support/ErrorHandling.cpp
  57. 4 0
      lib/Support/GraphWriter.cpp
  58. 4 0
      lib/Support/RandomNumberGenerator.cpp
  59. 5 2
      lib/Support/Statistic.cpp
  60. 9 9
      lib/Support/Timer.cpp
  61. 4 3
      lib/Support/Windows/Mutex.inc
  62. 12 0
      lib/Support/raw_ostream.cpp
  63. 14 0
      lib/Transforms/IPO/Inliner.cpp
  64. 4 0
      lib/Transforms/IPO/LowerBitSets.cpp
  65. 4 0
      lib/Transforms/IPO/MergeFunctions.cpp
  66. 22 7
      lib/Transforms/IPO/PassManagerBuilder.cpp
  67. 5 0
      lib/Transforms/Scalar/Float2Int.cpp
  68. 6 0
      lib/Transforms/Scalar/GVN.cpp
  69. 8 0
      lib/Transforms/Scalar/IndVarSimplify.cpp
  70. 4 0
      lib/Transforms/Scalar/JumpThreading.cpp
  71. 4 0
      lib/Transforms/Scalar/LICM.cpp
  72. 5 0
      lib/Transforms/Scalar/LoopDistribute.cpp
  73. 5 0
      lib/Transforms/Scalar/LoopRerollPass.cpp
  74. 4 0
      lib/Transforms/Scalar/LoopRotation.cpp
  75. 20 0
      lib/Transforms/Scalar/LoopUnrollPass.cpp
  76. 4 0
      lib/Transforms/Scalar/LoopUnswitch.cpp
  77. 11 1
      lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
  78. 6 0
      lib/Transforms/Scalar/SROA.cpp
  79. 5 0
      lib/Transforms/Scalar/SampleProfile.cpp
  80. 4 0
      lib/Transforms/Scalar/SimplifyCFGPass.cpp
  81. 5 0
      lib/Transforms/Utils/InlineFunction.cpp
  82. 8 1
      lib/Transforms/Utils/SimplifyCFG.cpp
  83. 6 2
      lib/Transforms/Utils/SimplifyLibCalls.cpp
  84. 3 1
      lib/Transforms/Utils/SymbolRewriter.cpp
  85. 7 2
      tools/clang/lib/AST/DeclarationName.cpp
  86. 1 0
      tools/clang/lib/Basic/SourceManager.cpp
  87. 14 0
      tools/clang/lib/Basic/VirtualFileSystem.cpp
  88. 8 0
      tools/clang/lib/CodeGen/CodeGenAction.cpp
  89. 21 3
      tools/clang/lib/CodeGen/CodeGenModule.cpp
  90. 1 1
      tools/clang/lib/CodeGen/CodeGenModule.h
  91. 6 3
      tools/clang/lib/CodeGen/TargetInfo.cpp
  92. 4 3
      tools/clang/lib/Frontend/ASTUnit.cpp
  93. 16 3
      tools/clang/lib/Frontend/CompilerInstance.cpp
  94. 2 2
      tools/clang/lib/Lex/PPLexerChange.cpp
  95. 1 0
      tools/clang/lib/Lex/Preprocessor.cpp
  96. 0 3
      tools/clang/test/HLSL/pix/removeDiscards.hlsl
  97. 84 9
      tools/clang/tools/dxc/dxc.cpp
  98. 64 17
      tools/clang/tools/dxcompiler/DXCompiler.cpp
  99. 1 0
      tools/clang/tools/dxcompiler/DXCompiler.def
  100. 36 16
      tools/clang/tools/dxcompiler/dxcapi.cpp

+ 6 - 0
include/dxc/Support/DxcLangExtensionsHelper.h

@@ -221,21 +221,27 @@ public:
 // Note that QueryInterface still needs to return the vtable.
 #define DXC_LANGEXTENSIONS_HELPER_IMPL(_helper_field_) \
   __override HRESULT STDMETHODCALLTYPE RegisterIntrinsicTable(_In_ IDxcIntrinsicTable *pTable) { \
+    DxcThreadMalloc TM(m_pMalloc); \
     return (_helper_field_).RegisterIntrinsicTable(pTable); \
   } \
   __override HRESULT STDMETHODCALLTYPE RegisterSemanticDefine(LPCWSTR name) { \
+    DxcThreadMalloc TM(m_pMalloc); \
     return (_helper_field_).RegisterSemanticDefine(name); \
   } \
   __override HRESULT STDMETHODCALLTYPE RegisterSemanticDefineExclusion(LPCWSTR name) { \
+    DxcThreadMalloc TM(m_pMalloc); \
     return (_helper_field_).RegisterSemanticDefineExclusion(name); \
   } \
   __override HRESULT STDMETHODCALLTYPE RegisterDefine(LPCWSTR name) { \
+    DxcThreadMalloc TM(m_pMalloc); \
     return (_helper_field_).RegisterDefine(name); \
   } \
   __override HRESULT STDMETHODCALLTYPE SetSemanticDefineValidator(_In_ IDxcSemanticDefineValidator* pValidator) { \
+    DxcThreadMalloc TM(m_pMalloc); \
     return (_helper_field_).SetSemanticDefineValidator(pValidator); \
   } \
   __override HRESULT STDMETHODCALLTYPE SetSemanticDefineMetaDataName(LPCSTR name) { \
+    DxcThreadMalloc TM(m_pMalloc); \
     return (_helper_field_).SetSemanticDefineMetaDataName(name); \
   } \
 

+ 113 - 4
include/dxc/Support/FileIOHelper.h

@@ -17,6 +17,99 @@ struct IDxcBlobEncoding;
 
 namespace hlsl {
 
+IMalloc *GetGlobalHeapMalloc() throw();
+
+class CDxcThreadMallocAllocator {
+public:
+  _Ret_maybenull_ _Post_writable_byte_size_(nBytes) _ATL_DECLSPEC_ALLOCATOR
+  static void *Reallocate(_In_ void *p, _In_ size_t nBytes) throw() {
+    return DxcGetThreadMallocNoRef()->Realloc(p, nBytes);
+  }
+
+  _Ret_maybenull_ _Post_writable_byte_size_(nBytes) _ATL_DECLSPEC_ALLOCATOR
+  static void *Allocate(_In_ size_t nBytes) throw() {
+    return DxcGetThreadMallocNoRef()->Alloc(nBytes);
+  }
+
+  static void Free(_In_ void *p) throw() {
+    return DxcGetThreadMallocNoRef()->Free(p);
+  }
+};
+
+// Like CComHeapPtr, but with CDxcThreadMallocAllocator.
+template <typename T>
+class CDxcTMHeapPtr :
+  public CHeapPtr<T, CDxcThreadMallocAllocator>
+{
+public:
+  CDxcTMHeapPtr() throw()
+  {
+  }
+
+  explicit CDxcTMHeapPtr(_In_ T* pData) throw() :
+    CDxcTMHeapPtr<T, CDxcThreadMallocAllocator>(pData)
+  {
+  }
+};
+
+// Like CComHeapPtr, but with a stateful allocator.
+template <typename T>
+class CDxcMallocHeapPtr
+{
+private:
+  CComPtr<IMalloc> m_pMalloc;
+public:
+  T *m_pData;
+
+  CDxcMallocHeapPtr(IMalloc *pMalloc) throw()
+      : m_pMalloc(pMalloc), m_pData(nullptr) {}
+
+  ~CDxcMallocHeapPtr() {
+    if (m_pData)
+      m_pMalloc->Free(m_pData);
+  }
+
+  operator T *() const throw() { return m_pData; }
+
+  bool Allocate(_In_ SIZE_T ElementCount) throw() {
+    ATLASSERT(m_pData == NULL);
+    SIZE_T nBytes = ElementCount * sizeof(T);
+    m_pData = static_cast<T *>(m_pMalloc->Alloc(nBytes));
+    if (m_pData == NULL)
+      return false;
+    return true;
+  }
+
+  void AllocateBytes(_In_ SIZE_T ByteCount) throw() {
+    if (m_pData)
+      m_pMalloc->Free(m_pData);
+    m_pData = static_cast<T *>(m_pMalloc->Alloc(ByteCount));
+  }
+
+  // Attach to an existing pointer (takes ownership)
+  void Attach(_In_ T *pData) throw() {
+    m_pMalloc->Free(m_pData);
+    m_pData = pData;
+  }
+
+  // Detach the pointer (releases ownership)
+  T *Detach() throw() {
+    T *pTemp = m_pData;
+    m_pData = NULL;
+    return pTemp;
+  }
+
+  // Free the memory pointed to, and set the pointer to NULL
+  void Free() throw() {
+    m_pMalloc->Free(m_pData);
+    m_pData = NULL;
+  }
+};
+
+void ReadBinaryFile(_In_opt_ IMalloc *pMalloc,
+                    _In_z_ LPCWSTR pFileName,
+                    _Outptr_result_bytebuffer_(*pDataSize) void **ppData,
+                    _Out_ DWORD *pDataSize);
 void ReadBinaryFile(_In_z_ LPCWSTR pFileName,
                     _Outptr_result_bytebuffer_(*pDataSize) void **ppData,
                     _Out_ DWORD *pDataSize);
@@ -30,8 +123,13 @@ void WriteBinaryFile(_In_z_ LPCWSTR pFileName,
 UINT32 DxcCodePageFromBytes(_In_count_(byteLen) const char *bytes,
                             size_t byteLen) throw();
 
+HRESULT
+DxcCreateBlobFromFile(_In_opt_ IMalloc *pMalloc, LPCWSTR pFileName,
+                      _In_opt_ UINT32 *pCodePage,
+                      _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
+
 HRESULT DxcCreateBlobFromFile(LPCWSTR pFileName, _In_opt_ UINT32 *pCodePage,
-                              _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
+                              _COM_Outptr_ IDxcBlobEncoding **ppBlobEncoding) throw();
 
 // Given a blob, creates a subrange view.
 HRESULT DxcCreateBlobFromBlob(_In_ IDxcBlob *pBlob, UINT32 offset,
@@ -49,7 +147,11 @@ DxcCreateBlobOnHeapCopy(_In_bytecount_(size) LPCVOID pData, UINT32 size,
 // Given a blob, creates a new instance with a specific code page set.
 HRESULT
 DxcCreateBlobWithEncodingSet(_In_ IDxcBlob *pBlob, UINT32 codePage,
-                             _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
+                             _COM_Outptr_ IDxcBlobEncoding **ppBlobEncoding) throw();
+HRESULT
+DxcCreateBlobWithEncodingSet(
+    _In_ IMalloc *pMalloc, _In_ IDxcBlob *pBlob, UINT32 codePage,
+    _COM_Outptr_ IDxcBlobEncoding **ppBlobEncoding) throw();
 
 HRESULT DxcCreateBlobWithEncodingFromPinned(
     _In_bytecount_(size) LPCVOID pText, UINT32 size, UINT32 codePage,
@@ -65,6 +167,7 @@ DxcCreateBlobWithEncodingOnHeap(_In_bytecount_(size) LPCVOID pText, UINT32 size,
                                 UINT32 codePage,
                                 _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
 
+// Should rename this 'OnHeap' to be 'OnMalloc', change callers to pass arg. Using TLS.
 HRESULT
 DxcCreateBlobWithEncodingOnHeapCopy(
     _In_bytecount_(size) LPCVOID pText, UINT32 size, UINT32 codePage,
@@ -75,6 +178,11 @@ DxcCreateBlobWithEncodingOnMalloc(
   _In_bytecount_(size) LPCVOID pText, IMalloc *pIMalloc, UINT32 size, UINT32 codePage,
   _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
 
+HRESULT
+DxcCreateBlobWithEncodingOnMallocCopy(
+  _In_ IMalloc *pIMalloc, _In_bytecount_(size) LPCVOID pText, UINT32 size, UINT32 codePage,
+  _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
+
 HRESULT DxcGetBlobAsUtf8(_In_ IDxcBlob *pBlob,
                          _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
 HRESULT
@@ -82,8 +190,9 @@ DxcGetBlobAsUtf8NullTerm(
     _In_ IDxcBlob *pBlob,
     _COM_Outptr_ IDxcBlobEncoding **ppBlobEncoding) throw();
 
-HRESULT DxcGetBlobAsUtf16(_In_ IDxcBlob *pBlob,
-                          _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
+HRESULT
+DxcGetBlobAsUtf16(_In_ IDxcBlob *pBlob, _In_ IMalloc *pMalloc,
+                  _COM_Outptr_ IDxcBlobEncoding **pBlobEncoding) throw();
 
 bool IsBlobNullOrEmpty(_In_opt_ IDxcBlob *pBlob) throw();
 

+ 48 - 1
include/dxc/Support/Global.h

@@ -5,7 +5,7 @@
 // This file is distributed under the University of Illinois Open Source     //
 // License. See LICENSE.TXT for details.                                     //
 //                                                                           //
-// Provides important declarations global to all DX Compiler code.          //
+// Provides important declarations global to all DX Compiler code.           //
 //                                                                           //
 ///////////////////////////////////////////////////////////////////////////////
 
@@ -25,6 +25,53 @@ typedef _Return_type_success_(return >= 0) long HRESULT;
 #include <stdarg.h>
 #include "dxc/Support/exception.h"
 
+///////////////////////////////////////////////////////////////////////////////
+// Memory allocation support.
+//
+// This mechanism ties into the C++ new and delete operators.
+//
+// Other allocators may be used in specific situations, eg sub-allocators or
+// the COM allocator for interop. This is the preferred allocator in general,
+// however, as it eventually allows the library user to specify their own.
+//
+
+struct IMalloc;
+
+// Used by DllMain to set up and tear down per-thread tracking.
+HRESULT DxcInitThreadMalloc() throw();
+void DxcCleanupThreadMalloc() throw();
+
+// Used by APIs entry points to set up per-thread/invocation allocator.
+// Setting the IMalloc on the thread increases the reference count,
+// clearing it decreases it.
+void DxcClearThreadMalloc() throw();
+void DxcSetThreadMalloc(IMalloc *pMalloc) throw();
+void DxcSetThreadMallocOrDefault(IMalloc *pMalloc) throw();
+
+// Swapping does not AddRef or Release new or prior. The pattern is to keep both alive,
+// either in TLS, or on the stack to restore later. The returned value is the effective
+// IMalloc also available in TLS.
+IMalloc *DxcSwapThreadMalloc(IMalloc *pMalloc, IMalloc **ppPrior) throw();
+IMalloc *DxcSwapThreadMallocOrDefault(IMalloc *pMalloc, IMalloc **ppPrior) throw();
+
+// Used to retrieve the current invocation's allocator or perform an alloc/free/realloc.
+IMalloc *DxcGetThreadMallocNoRef() throw();
+_Ret_maybenull_ _Post_writable_byte_size_(nBytes) void *DxcThreadAlloc(size_t nBytes) throw();
+void DxcThreadFree(void *) throw();
+
+struct DxcThreadMalloc {
+  DxcThreadMalloc(IMalloc *pMallocOrNull) throw() {
+    p = DxcSwapThreadMallocOrDefault(pMallocOrNull, &pPrior);
+  }
+  ~DxcThreadMalloc() {
+    DxcSwapThreadMalloc(pPrior, nullptr);
+  }
+  IMalloc *p;
+  IMalloc *pPrior;
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// Error handling support.
 namespace std { class error_code; }
 void CheckLLVMErrorCode(const std::error_code &ec);
 

+ 2 - 0
include/dxc/Support/HLSLOptions.h

@@ -52,6 +52,8 @@ enum ID {
   };
 
 const llvm::opt::OptTable *getHlslOptTable();
+std::error_code initHlslOptTable();
+void cleanupHlslOptTable();
 
 ///////////////////////////////////////////////////////////////////////////////
 // Helper classes to deal with options.

+ 1 - 0
include/dxc/Support/WinIncludes.h

@@ -32,6 +32,7 @@
 #include <atlbase.h> // atlbase.h needs to come before strsafe.h
 #include <strsafe.h>
 #include <intsafe.h>
+#include <ObjIdl.h>
 
 /// Swap two ComPtr classes.
 template <class T> void swap(CComHeapPtr<T> &a, CComHeapPtr<T> &b) {

+ 12 - 8
include/dxc/Support/dxcapi.impl.h

@@ -34,15 +34,18 @@ public:
 
 class DxcOperationResult : public IDxcOperationResult {
 private:
-  DXC_MICROCOM_REF_FIELD(m_dwRef)
+  DXC_MICROCOM_TM_REF_FIELDS()
 
-  DxcOperationResult(_In_opt_ IDxcBlob *pResultBlob,
-    _In_opt_ IDxcBlobEncoding *pErrorBlob, HRESULT status)
-    : m_dwRef(0), m_status(status), m_result(pResultBlob),
-    m_errors(pErrorBlob) {}
+  void Init(_In_opt_ IDxcBlob *pResultBlob,
+            _In_opt_ IDxcBlobEncoding *pErrorBlob, HRESULT status) {
+    m_status = status;
+    m_result = pResultBlob;
+    m_errors = pErrorBlob;
+  }
 
 public:
-  DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
+  DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
+  DXC_MICROCOM_TM_CTOR(DxcOperationResult)
 
   HRESULT m_status;
   CComPtr<IDxcBlob> m_result;
@@ -57,8 +60,9 @@ public:
                                              HRESULT status,
                                              _COM_Outptr_ IDxcOperationResult **ppResult) {
     *ppResult = nullptr;
-    CComPtr<DxcOperationResult> result = new (std::nothrow) DxcOperationResult(pResultBlob, pErrorBlob, status);
-    if (result.p == nullptr) return E_OUTOFMEMORY;
+    CComPtr<DxcOperationResult> result = DxcOperationResult::Alloc(DxcGetThreadMallocNoRef());
+    IFROOM(result.p);
+    result->Init(pResultBlob, pErrorBlob, status);
     *ppResult = result.Detach();
     return S_OK;
   }

+ 33 - 2
include/dxc/Support/dxcapi.use.h

@@ -21,6 +21,7 @@ class DxcDllSupport {
 protected:
   HMODULE m_dll;
   DxcCreateInstanceProc m_createFn;
+  DxcCreateInstance2Proc m_createFn2;
 
   HRESULT InitializeInternal(LPCWSTR dllName, LPCSTR fnName) {
     if (m_dll != nullptr) return S_OK;
@@ -36,16 +37,28 @@ protected:
       return hr;
     }
 
+    // Only basic functions used to avoid requiring additional headers.
+    m_createFn2 = nullptr;
+    char fnName2[128];
+    size_t s = strlen(fnName);
+    if (s < sizeof(fnName2) - 2) {
+      memcpy(fnName2, fnName, s);
+      fnName2[s] = '2';
+      fnName2[s + 1] = '\0';
+      m_createFn2 = (DxcCreateInstance2Proc)GetProcAddress(m_dll, fnName2);
+    }
+
     return S_OK;
   }
 
 public:
-  DxcDllSupport() : m_dll(nullptr), m_createFn(nullptr) {
+  DxcDllSupport() : m_dll(nullptr), m_createFn(nullptr), m_createFn2(nullptr) {
   }
 
   DxcDllSupport(DxcDllSupport&& other) {
     m_dll = other.m_dll; other.m_dll = nullptr;
-    m_createFn = other.m_createFn; other.m_dll = nullptr;
+    m_createFn = other.m_createFn; other.m_createFn = nullptr;
+    m_createFn2 = other.m_createFn2; other.m_createFn2 = nullptr;
   }
 
   ~DxcDllSupport() {
@@ -72,6 +85,23 @@ public:
     return hr;
   }
 
+  template <typename TInterface>
+  HRESULT CreateInstance2(IMalloc *pMalloc, REFCLSID clsid, _Outptr_ TInterface** pResult) {
+    return CreateInstance2(pMalloc, clsid, __uuidof(TInterface), (IUnknown**)pResult);
+  }
+
+  HRESULT CreateInstance2(IMalloc *pMalloc, REFCLSID clsid, REFIID riid, _Outptr_ IUnknown **pResult) {
+    if (pResult == nullptr) return E_POINTER;
+    if (m_dll == nullptr) return E_FAIL;
+    if (m_createFn2 == nullptr) return E_FAIL;
+    HRESULT hr = m_createFn2(pMalloc, clsid, riid, (LPVOID*)pResult);
+    return hr;
+  }
+
+  bool HasCreateWithMalloc() const {
+    return m_createFn2 != nullptr;
+  }
+
   bool IsEnabled() const {
     return m_dll != nullptr;
   }
@@ -79,6 +109,7 @@ public:
   void Cleanup() {
     if (m_dll != nullptr) {
       m_createFn = nullptr;
+      m_createFn2 = nullptr;
       FreeLibrary(m_dll);
       m_dll = nullptr;
     }

+ 2 - 3
include/dxc/Support/dxcfilesystem.h

@@ -42,9 +42,8 @@ public:
   virtual HRESULT RegisterOutputStream(LPCWSTR pName, IStream *pStream) = 0;
 };
 
-HRESULT
+DxcArgsFileSystem *
 CreateDxcArgsFileSystem(_In_ IDxcBlob *pSource, _In_ LPCWSTR pSourceName,
-                        _In_opt_ IDxcIncludeHandler *pIncludeHandler,
-                        _Outptr_ DxcArgsFileSystem **ppResult) throw();
+                        _In_opt_ IDxcIncludeHandler *pIncludeHandler);
 
 } // namespace dxcutil

+ 36 - 6
include/dxc/Support/microcom.h

@@ -73,19 +73,49 @@ public:
   }
 };
 
-#define DXC_MICROCOM_REF_FIELD(m_dwRef) volatile ULONG m_dwRef;
-    
-#define DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef) \
-    bool HasSingleRef() { return 1 == m_dwRef; } \
-    ULONG STDMETHODCALLTYPE AddRef()  {\
+#define DXC_MICROCOM_REF_FIELD(m_dwRef) volatile ULONG m_dwRef = 0;
+#define DXC_MICROCOM_ADDREF_IMPL(m_dwRef) \
+    ULONG STDMETHODCALLTYPE AddRef() {\
         return InterlockedIncrement(&m_dwRef); \
-    } \
+    }
+#define DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef) \
+    DXC_MICROCOM_ADDREF_IMPL(m_dwRef) \
     ULONG STDMETHODCALLTYPE Release() { \
         ULONG result = InterlockedDecrement(&m_dwRef); \
         if (result == 0) delete this; \
         return result; \
     }
 
+template <typename T, typename... Args>
+inline T *CreateOnMalloc(IMalloc * pMalloc, Args&&... args) {
+  void *P = pMalloc->Alloc(sizeof(T)); \
+  if (P) new (P)T(pMalloc, std::forward<Args>(args)...); \
+  return (T *)P; \
+}
+
+// The "TM" version keep an IMalloc field that, if not null, indicate
+// ownership of 'this' and of any allocations used during release.
+#define DXC_MICROCOM_TM_REF_FIELDS() \
+  volatile ULONG m_dwRef = 0;\
+  CComPtr<IMalloc> m_pMalloc;
+#define DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL() \
+    DXC_MICROCOM_ADDREF_IMPL(m_dwRef) \
+    ULONG STDMETHODCALLTYPE Release() { \
+      ULONG result = InterlockedDecrement(&m_dwRef); \
+      if (result == 0) { \
+        CComPtr<IMalloc> pTmp(m_pMalloc); \
+        DxcThreadMalloc M(pTmp); delete this; \
+      } \
+      return result; \
+    }
+#define DXC_MICROCOM_TM_CTOR(T) \
+  T(IMalloc *pMalloc) : m_dwRef(0), m_pMalloc(pMalloc) { } \
+  static T* Alloc(IMalloc *pMalloc) { \
+    void *P = pMalloc->Alloc(sizeof(T)); \
+    try { if (P) new (P)T(pMalloc); } catch (...) { operator delete(P); throw; } \
+    return (T *)P; \
+  }
+
 /// <summary>
 /// Provides a QueryInterface implementation for a class that supports
 /// any number of interfaces in addition to IUnknown.

+ 15 - 0
include/dxc/dxcapi.h

@@ -17,6 +17,7 @@
 #define DXC_API_IMPORT __declspec(dllimport)
 #endif
 
+struct IMalloc;
 struct IDxcIncludeHandler;
 
 /// <summary>
@@ -42,6 +43,13 @@ typedef HRESULT (__stdcall *DxcCreateInstanceProc)(
     _Out_ LPVOID*   ppv
 );
 
+typedef HRESULT(__stdcall *DxcCreateInstance2Proc)(
+  _In_ IMalloc    *pMalloc,
+  _In_ REFCLSID   rclsid,
+  _In_ REFIID     riid,
+  _Out_ LPVOID*   ppv
+  );
+
 /// <summary>
 /// Creates a single uninitialized object of the class associated with a specified CLSID.
 /// </summary>
@@ -65,6 +73,13 @@ DXC_API_IMPORT HRESULT __stdcall DxcCreateInstance(
   _Out_ LPVOID*   ppv
   );
 
+DXC_API_IMPORT HRESULT __stdcall DxcCreateInstance2(
+  _In_ IMalloc    *pMalloc,
+  _In_ REFCLSID   rclsid,
+  _In_ REFIID     riid,
+  _Out_ LPVOID*   ppv
+);
+
 
 // IDxcBlob is an alias of ID3D10Blob and ID3DBlob
 struct __declspec(uuid("8BA5FB08-5195-40e2-AC58-0D989C3A0102"))

+ 18 - 5
include/llvm/ADT/DenseMap.h

@@ -676,13 +676,16 @@ private:
   }
 
   bool allocateBuckets(unsigned Num) {
-    NumBuckets = Num;
-    if (NumBuckets == 0) {
+    // HLSL Change Starts - reorder statement to clean up properly on OOM
+    if (Num == 0) {
+      NumBuckets = 0;
       Buckets = nullptr;
       return false;
     }
 
-    Buckets = static_cast<BucketT*>(operator new(sizeof(BucketT) * NumBuckets));
+    Buckets = static_cast<BucketT*>(operator new(sizeof(BucketT) * Num));
+    NumBuckets = Num;
+    // HLSL Change Ends - reorder statement to clean up properly on OOM
     return true;
   }
 };
@@ -876,8 +879,8 @@ public:
 
       // Now make this map use the large rep, and move all the entries back
       // into it.
-      Small = false;
-      new (getLargeRep()) LargeRep(allocateBuckets(AtLeast));
+      new (getLargeRepForTransition()) LargeRep(allocateBuckets(AtLeast));
+      Small = false; // HLSL Change - used to be prior to allocation
       this->moveFromOldBuckets(TmpBegin, TmpEnd);
       return;
     }
@@ -953,6 +956,16 @@ private:
     return const_cast<LargeRep *>(
       const_cast<const SmallDenseMap *>(this)->getLargeRep());
   }
+  // HLSL Change Starts - avoid Small check, as we are in the process of transitioning
+  const LargeRep *getLargeRepForTransition() const {
+    // Note, same rule about aliasing as with getInlineBuckets.
+    return reinterpret_cast<const LargeRep *>(storage.buffer);
+  }
+  LargeRep *getLargeRepForTransition() {
+    return const_cast<LargeRep *>(
+      const_cast<const SmallDenseMap *>(this)->getLargeRepForTransition());
+  }
+  // HLSL Change Ends
 
   const BucketT *getBuckets() const {
     return Small ? getInlineBuckets() : getLargeRep()->Buckets;

+ 1 - 1
include/llvm/ADT/Statistic.h

@@ -52,7 +52,7 @@ public:
   // Allow use of this class as the value itself.
   operator unsigned() const { return Value; }
 
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_STATS)
+#if (!defined(NDEBUG) || defined(LLVM_ENABLE_STATS)) && 0 // HLSL Change - always disable, shouldn't do process-wide alloc in compile
    const Statistic &operator=(unsigned Val) {
     Value = Val;
     return init();

+ 1 - 0
include/llvm/Analysis/CallGraph.h

@@ -103,6 +103,7 @@ class CallGraph {
   /// functions that it calls.
   void addToCallGraph(Function *F);
 
+  void reset(); // HLSL Change
 public:
   CallGraph(Module &M);
   ~CallGraph();

+ 2 - 0
include/llvm/Bitcode/BitstreamWriter.h

@@ -92,8 +92,10 @@ public:
     : Out(O), CurBit(0), CurValue(0), CurCodeSize(2) {}
 
   ~BitstreamWriter() {
+#if 0 // HLSL Change - these are not true when recovering from OOM
     assert(CurBit == 0 && "Unflushed data remaining");
     assert(BlockScope.empty() && CurAbbrevs.empty() && "Block imbalance");
+#endif
   }
 
   /// \brief Retrieve the current position in the stream, in bits.

+ 2 - 2
include/llvm/IR/LegacyPassManagers.h

@@ -143,7 +143,7 @@ public:
 /// suitable manager.
 class PMStack {
 public:
-  typedef std::vector<PMDataManager *>::const_reverse_iterator iterator;
+  typedef llvm::SmallVector<PMDataManager *, 2>::const_reverse_iterator iterator; // HLSL Change - SmallVector rather than vector
   iterator begin() const { return S.rbegin(); }
   iterator end() const { return S.rend(); }
 
@@ -155,7 +155,7 @@ public:
   void dump() const;
 
 private:
-  std::vector<PMDataManager *> S;
+  llvm::SmallVector<PMDataManager *, 2> S; // HLSL Change - SmallVector rather than vector
 };
 
 

+ 5 - 2
include/llvm/IR/Metadata.h

@@ -755,8 +755,9 @@ protected:
   void operator delete(void *Mem);
 
   /// \brief Required by std, but never called.
-  void operator delete(void *, unsigned) {
-    llvm_unreachable("Constructor throws?");
+  void operator delete(void *Mem, unsigned) {
+    //llvm_unreachable("Constructor throws?"); // HLSL Change - why, yes; yes it does (under OOM)
+    MDNode::operator delete(Mem);
   }
 
   /// \brief Required by std, but never called.
@@ -903,7 +904,9 @@ private:
   /// \pre \a isTemporary().
   void makeDistinct();
 
+public: // HLSL Change - make deleteAsSubclass accessible
   void deleteAsSubclass();
+private:
   MDNode *uniquify();
   void eraseFromStore();
 

+ 2 - 3
include/llvm/IR/User.h

@@ -77,9 +77,8 @@ public:
   /// \brief Free memory allocated for User and Use objects.
   void operator delete(void *Usr);
   /// \brief Placement delete - required by std, but never called.
-  void operator delete(void*, unsigned) {
-    llvm_unreachable("Constructor throws?");
-  }
+  void operator delete(void*, unsigned);
+    // llvm_unreachable("Constructor throws?"); - HLSL Change: it does on OOM
   /// \brief Placement delete - required by std, but never called.
   void operator delete(void*, unsigned, bool) {
     llvm_unreachable("Constructor throws?");

+ 2 - 2
include/llvm/Support/CommandLine.h

@@ -169,7 +169,7 @@ public:
 };
 
 // The general Option Category (used as default category).
-extern OptionCategory GeneralCategory;
+extern OptionCategory *GeneralCategory; // HLSL Change - GeneralCategory is now a pointer
 
 //===----------------------------------------------------------------------===//
 // Option Base class
@@ -251,7 +251,7 @@ protected:
       : NumOccurrences(0), Occurrences(OccurrencesFlag), Value(0),
         HiddenFlag(Hidden), Formatting(NormalFormatting), Misc(0), Position(0),
         AdditionalVals(0), ArgStr(""), HelpStr(""), ValueStr(""),
-        Category(&GeneralCategory), FullyInitialized(false) {}
+        Category(GeneralCategory), FullyInitialized(false) {} // HLSL Change - not GeneralCategory
 
   inline void setNumAdditionalVals(unsigned n) { AdditionalVals = n; }
 

+ 4 - 0
include/llvm/Support/Mutex.h

@@ -71,7 +71,11 @@ namespace llvm
     /// @{
     private:
 #if defined(LLVM_ENABLE_THREADS) && LLVM_ENABLE_THREADS != 0
+#if 0 // HLSL Change
       void* data_; ///< We don't know what the data will be
+#else
+      char data_[40]; // C_ASSERT this is CRITICAL_SECTION-sized
+#endif // HLSL Change
 #endif
 
     /// @}

+ 9 - 2
include/llvm/Support/raw_ostream.h

@@ -91,6 +91,13 @@ public:
   /// tell - Return the current offset with the file.
   uint64_t tell() const { return current_pos() + GetNumBytesInBuffer(); }
 
+  // HLSL Change Starts - needed to clean up properly
+  virtual void close() { flush(); }
+  virtual bool has_error() const { return false; }
+  virtual void clear_error() { }
+  // HLSL Change Ends
+
+
   //===--------------------------------------------------------------------===//
   // Configuration Interface
   //===--------------------------------------------------------------------===//
@@ -427,7 +434,7 @@ public:
   /// output error has been encountered.
   /// This doesn't implicitly flush any pending output.  Also, it doesn't
   /// guarantee to detect all errors unless the stream has been closed.
-  bool has_error() const {
+  bool has_error() const override {
     return Error;
   }
 
@@ -440,7 +447,7 @@ public:
   ///    Unless explicitly silenced."
   ///      - from The Zen of Python, by Tim Peters
   ///
-  void clear_error() {
+  void clear_error() override {
     Error = false;
   }
 };

+ 4 - 0
lib/Analysis/DependenceAnalysis.cpp

@@ -107,9 +107,13 @@ STATISTIC(BanerjeeApplications, "Banerjee applications");
 STATISTIC(BanerjeeIndependence, "Banerjee independence");
 STATISTIC(BanerjeeSuccesses, "Banerjee successes");
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool>
 Delinearize("da-delinearize", cl::init(false), cl::Hidden, cl::ZeroOrMore,
             cl::desc("Try to delinearize array references."));
+#else
+static const bool Delinearize = false;
+#endif // HLSL Change Ends
 
 //===----------------------------------------------------------------------===//
 // basics

+ 28 - 12
lib/Analysis/IPA/CallGraph.cpp

@@ -22,30 +22,46 @@ using namespace llvm;
 //
 
 CallGraph::CallGraph(Module &M)
-    : M(M), Root(nullptr), ExternalCallingNode(getOrInsertFunction(nullptr)),
-      CallsExternalNode(new CallGraphNode(nullptr)) {
-  // Add every function to the call graph.
-  for (Function &F : M)
-    addToCallGraph(&F);
-
-  // If we didn't find a main function, use the external call graph node
-  if (!Root)
-    Root = ExternalCallingNode;
+    : M(M), Root(nullptr), ExternalCallingNode(nullptr), // HLSL Change - no allocation here
+      CallsExternalNode(nullptr) {
+  try { // HLSL change - guard and reset
+    ExternalCallingNode = getOrInsertFunction(nullptr);
+    CallsExternalNode = new CallGraphNode(nullptr);
+    // Add every function to the call graph.
+    for (Function &F : M)
+      addToCallGraph(&F);
+
+    // If we didn't find a main function, use the external call graph node
+    if (!Root)
+      Root = ExternalCallingNode;
+  } catch (...) {
+    reset();
+    throw;
+  }
 }
 
-CallGraph::~CallGraph() {
+// HLSL Change Starts
+CallGraph::~CallGraph() { reset(); }
+void CallGraph::reset() {
+  // This function cleans up the CallGraph, called from the destructor or
+  // an under-construction instance.
+// HLSL Change Ends
   // CallsExternalNode is not in the function map, delete it explicitly.
-  CallsExternalNode->allReferencesDropped();
+  if (CallsExternalNode) // HLSL Change - guard
+    CallsExternalNode->allReferencesDropped();
   delete CallsExternalNode;
+  CallsExternalNode = nullptr;
 
 // Reset all node's use counts to zero before deleting them to prevent an
 // assertion from firing.
 #ifndef NDEBUG
   for (auto &I : FunctionMap)
-    I.second->allReferencesDropped();
+    if (I.second) // HLSL Change - this guard needed when slot is alloc'ed but not populated
+      I.second->allReferencesDropped();
 #endif
   for (auto &I : FunctionMap)
     delete I.second;
+  FunctionMap.clear();
 }
 
 void CallGraph::addToCallGraph(Function *F) {

+ 6 - 0
lib/Analysis/IPA/CallGraphSCCPass.cpp

@@ -31,8 +31,12 @@ using namespace llvm;
 
 #define DEBUG_TYPE "cgscc-passmgr"
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<unsigned> 
 MaxIterations("max-cg-scc-iterations", cl::ReallyHidden, cl::init(4));
+#else
+static const unsigned MaxIterations = 4;
+#endif
 
 STATISTIC(MaxSCCIterations, "Maximum CGSCCPassMgr iterations on one SCC");
 
@@ -546,6 +550,7 @@ void CallGraphSCC::ReplaceNode(CallGraphNode *Old, CallGraphNode *New) {
 /// Assign pass manager to manage this pass.
 void CallGraphSCCPass::assignPassManager(PMStack &PMS,
                                          PassManagerType PreferredType) {
+  std::unique_ptr<CallGraphSCCPass> thisPtr(this); // HLSL Change
   // Find CGPassManager 
   while (!PMS.empty() &&
          PMS.top()->getPassManagerType() > PMT_CallGraphPassManager)
@@ -577,6 +582,7 @@ void CallGraphSCCPass::assignPassManager(PMStack &PMS,
     PMS.push(CGP);
   }
 
+  thisPtr.release();
   CGP->add(this);
 }
 

+ 10 - 1
lib/Analysis/LoopAccessAnalysis.cpp

@@ -27,6 +27,7 @@ using namespace llvm;
 
 #define DEBUG_TYPE "loop-accesses"
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<unsigned, true>
 VectorizationFactor("force-vector-width", cl::Hidden,
                     cl::desc("Sets the SIMD width. Zero is autoselect."),
@@ -64,9 +65,17 @@ static cl::opt<unsigned> MaxInterestingDependence(
     cl::desc("Maximum number of interesting dependences collected by "
              "loop-access analysis (default = 100)"),
     cl::init(100));
+#else
+unsigned VectorizerParams::VectorizationInterleave;
+unsigned VectorizerParams::VectorizationFactor;
+unsigned VectorizerParams::RuntimeMemoryCheckThreshold = 8;
+static const unsigned MemoryCheckMergeThreshold = 100;
+const unsigned VectorizerParams::MaxVectorWidth = 64;
+static const unsigned MaxInterestingDependence = 100;
+#endif // HLSL Change Ends
 
 bool VectorizerParams::isInterleaveForced() {
-  return ::VectorizationInterleave.getNumOccurrences() > 0;
+  return false; // HLSL Change - instead of return ::VectorizationInterleave.getNumOccurrences() > 0;
 }
 
 void LoopAccessReport::emitAnalysis(const LoopAccessReport &Message,

+ 3 - 0
lib/Analysis/LoopInfo.cpp

@@ -43,9 +43,12 @@ static bool VerifyLoopInfo = true;
 #else
 static bool VerifyLoopInfo = false;
 #endif
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool,true>
 VerifyLoopInfoX("verify-loop-info", cl::location(VerifyLoopInfo),
                 cl::desc("Verify loop info (time consuming)"));
+#else
+#endif // HLSL Change Ends
 
 // Loop identifier metadata name.
 static const char *const LoopMDName = "llvm.loop";

+ 2 - 0
lib/Analysis/LoopPass.cpp

@@ -354,6 +354,7 @@ void LoopPass::preparePassManager(PMStack &PMS) {
 /// Assign pass manager to manage this pass.
 void LoopPass::assignPassManager(PMStack &PMS,
                                  PassManagerType PreferredType) {
+  std::unique_ptr<LoopPass> thisPtr(this); // HLSL Change
   // Find LPPassManager
   while (!PMS.empty() &&
          PMS.top()->getPassManagerType() > PMT_LoopPassManager)
@@ -384,6 +385,7 @@ void LoopPass::assignPassManager(PMStack &PMS,
     PMS.push(LPPM);
   }
 
+  thisPtr.release(); // HLSL Change
   LPPM->add(this);
 }
 

+ 2 - 0
lib/Analysis/RegionPass.cpp

@@ -241,6 +241,7 @@ void RegionPass::preparePassManager(PMStack &PMS) {
 /// Assign pass manager to manage this pass.
 void RegionPass::assignPassManager(PMStack &PMS,
                                  PassManagerType PreferredType) {
+  std::unique_ptr<RegionPass> thisPtr(this); // HLSL Change
   // Find RGPassManager
   while (!PMS.empty() &&
          PMS.top()->getPassManagerType() > PMT_RegionPassManager)
@@ -272,6 +273,7 @@ void RegionPass::assignPassManager(PMStack &PMS,
     PMS.push(RGPM);
   }
 
+  thisPtr.release(); // HLSL Change
   RGPM->add(this);
 }
 

+ 5 - 0
lib/Analysis/ScalarEvolution.cpp

@@ -102,6 +102,7 @@ STATISTIC(NumTripCountsNotComputed,
 STATISTIC(NumBruteForceTripCountsComputed,
           "Number of loops with trip counts computed by force");
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<unsigned>
 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                         cl::desc("Maximum number of iterations SCEV will "
@@ -113,6 +114,10 @@ MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
 static cl::opt<bool>
 VerifySCEV("verify-scev",
            cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
+#else
+static const unsigned MaxBruteForceIterations = 100;
+static const bool VerifySCEV = false;
+#endif // HLSL Change Ends
 
 INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution",
                 "Scalar Evolution Analysis", false, true)

+ 4 - 0
lib/Analysis/ScopedNoAliasAA.cpp

@@ -46,8 +46,12 @@ using namespace llvm;
 // A handy option for disabling scoped no-alias functionality. The same effect
 // can also be achieved by stripping the associated metadata tags from IR, but
 // this option is sometimes more convenient.
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool>
 EnableScopedNoAlias("enable-scoped-noalias", cl::init(true));
+#else
+static const bool EnableScopedNoAlias = true;
+#endif // HLSL Change Ends
 
 namespace {
 /// AliasScopeNode - This is a simple wrapper around an MDNode which provides

+ 4 - 0
lib/Analysis/TargetLibraryInfo.cpp

@@ -16,6 +16,7 @@
 #include "llvm/Support/CommandLine.h"
 using namespace llvm;
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<TargetLibraryInfoImpl::VectorLibrary> ClVectorLibrary(
     "vector-library", cl::Hidden, cl::desc("Vector functions library"),
     cl::init(TargetLibraryInfoImpl::NoLibrary),
@@ -24,6 +25,9 @@ static cl::opt<TargetLibraryInfoImpl::VectorLibrary> ClVectorLibrary(
                clEnumValN(TargetLibraryInfoImpl::Accelerate, "Accelerate",
                           "Accelerate framework"),
                clEnumValEnd));
+#else
+static const TargetLibraryInfoImpl::VectorLibrary ClVectorLibrary = TargetLibraryInfoImpl::NoLibrary;
+#endif // HLSL Change Ends
 
 const char *const TargetLibraryInfoImpl::StandardNames[LibFunc::NumLibFuncs] = {
 #define TLI_DEFINE_STRING

+ 4 - 0
lib/Analysis/TypeBasedAliasAnalysis.cpp

@@ -135,7 +135,11 @@ using namespace llvm;
 // A handy option for disabling TBAA functionality. The same effect can also be
 // achieved by stripping the !tbaa tags from IR, but this option is sometimes
 // more convenient.
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool> EnableTBAA("enable-tbaa", cl::init(true));
+#else
+static const bool EnableTBAA = true;
+#endif // HLSL Change Ends
 
 namespace {
   /// TBAANode - This is a simple wrapper around an MDNode which provides a

+ 9 - 1
lib/Analysis/ValueTracking.cpp

@@ -41,6 +41,7 @@ using namespace llvm::PatternMatch;
 
 const unsigned MaxDepth = 6;
 
+#if 0 // HLSL Change Starts - option pending
 /// Enable an experimental feature to leverage information about dominating
 /// conditions to compute known bits.  The individual options below control how
 /// hard we search.  The defaults are choosen to be fairly aggressive.  If you
@@ -52,7 +53,7 @@ static cl::opt<bool> EnableDomConditions("value-tracking-dom-conditions",
 // This is expensive, so we only do it for the top level query value.
 // (TODO: evaluate cost vs profit, consider higher thresholds)
 static cl::opt<unsigned> DomConditionsMaxDepth("dom-conditions-max-depth",
-                                               cl::Hidden, cl::init(1));
+                                         cl::Hidden, cl::init(1));
 
 /// How many dominating blocks should be scanned looking for dominating
 /// conditions?
@@ -68,6 +69,13 @@ static cl::opt<unsigned> DomConditionsMaxUses("dom-conditions-max-uses",
 // If true, don't consider only compares whose only use is a branch.
 static cl::opt<bool> DomConditionsSingleCmpUse("dom-conditions-single-cmp-use",
                                                cl::Hidden, cl::init(false));
+#else
+static const bool EnableDomConditions = false;
+static const unsigned DomConditionsMaxDepth = 1;
+static const unsigned DomConditionsMaxDomBlocks = 2000;
+static const unsigned DomConditionsMaxUses = 2000;
+static const bool DomConditionsSingleCmpUse = false;
+#endif // HLSL Change Ends
 
 /// Returns the bitwidth of the given scalar or pointer type (if unknown returns
 /// 0). For vector types, returns the element type's bitwidth.

+ 3 - 1
lib/Analysis/regioninfo.cpp

@@ -37,6 +37,7 @@ STATISTIC(numSimpleRegions, "The # of simple regions");
 
 // Always verify if expensive checking is enabled.
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool,true>
 VerifyRegionInfoX(
   "verify-region-info",
@@ -55,7 +56,8 @@ static cl::opt<Region::PrintStyle, true> printStyleX("print-region-style",
     clEnumValN(Region::PrintRN, "rn",
                "print regions in detail with element_iterator"),
     clEnumValEnd));
-
+#else
+#endif // HLSL Change Ends
 
 //===----------------------------------------------------------------------===//
 // Region implementation

+ 54 - 54
lib/Bitcode/Writer/BitcodeWriter.cpp

@@ -335,11 +335,11 @@ static void WriteTypeTable(const ValueEnumerator &VE, BitstreamWriter &Stream) {
   uint64_t NumBits = VE.computeBitsRequiredForTypeIndicies();
 
   // Abbrev for TYPE_CODE_POINTER.
-  BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+  IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
   Abbv->Add(BitCodeAbbrevOp(bitc::TYPE_CODE_POINTER));
   Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, NumBits));
   Abbv->Add(BitCodeAbbrevOp(0));  // Addrspace = 0
-  unsigned PtrAbbrev = Stream.EmitAbbrev(Abbv);
+  unsigned PtrAbbrev = Stream.EmitAbbrev(Abbv.get());
 
   // Abbrev for TYPE_CODE_FUNCTION.
   Abbv = new BitCodeAbbrev();
@@ -348,7 +348,7 @@ static void WriteTypeTable(const ValueEnumerator &VE, BitstreamWriter &Stream) {
   Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
   Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, NumBits));
 
-  unsigned FunctionAbbrev = Stream.EmitAbbrev(Abbv);
+  unsigned FunctionAbbrev = Stream.EmitAbbrev(Abbv.get());
 
   // Abbrev for TYPE_CODE_STRUCT_ANON.
   Abbv = new BitCodeAbbrev();
@@ -357,14 +357,14 @@ static void WriteTypeTable(const ValueEnumerator &VE, BitstreamWriter &Stream) {
   Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
   Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, NumBits));
 
-  unsigned StructAnonAbbrev = Stream.EmitAbbrev(Abbv);
+  unsigned StructAnonAbbrev = Stream.EmitAbbrev(Abbv.get());
 
   // Abbrev for TYPE_CODE_STRUCT_NAME.
   Abbv = new BitCodeAbbrev();
   Abbv->Add(BitCodeAbbrevOp(bitc::TYPE_CODE_STRUCT_NAME));
   Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
   Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Char6));
-  unsigned StructNameAbbrev = Stream.EmitAbbrev(Abbv);
+  unsigned StructNameAbbrev = Stream.EmitAbbrev(Abbv.get());
 
   // Abbrev for TYPE_CODE_STRUCT_NAMED.
   Abbv = new BitCodeAbbrev();
@@ -373,7 +373,7 @@ static void WriteTypeTable(const ValueEnumerator &VE, BitstreamWriter &Stream) {
   Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
   Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, NumBits));
 
-  unsigned StructNamedAbbrev = Stream.EmitAbbrev(Abbv);
+  unsigned StructNamedAbbrev = Stream.EmitAbbrev(Abbv.get());
 
   // Abbrev for TYPE_CODE_ARRAY.
   Abbv = new BitCodeAbbrev();
@@ -381,7 +381,7 @@ static void WriteTypeTable(const ValueEnumerator &VE, BitstreamWriter &Stream) {
   Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 8));   // size
   Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, NumBits));
 
-  unsigned ArrayAbbrev = Stream.EmitAbbrev(Abbv);
+  unsigned ArrayAbbrev = Stream.EmitAbbrev(Abbv.get());
 
   // Emit an entry count so the reader can reserve space.
   TypeVals.push_back(TypeList.size());
@@ -633,7 +633,7 @@ static void WriteModuleInfo(const Module *M, const ValueEnumerator &VE,
   unsigned SimpleGVarAbbrev = 0;
   if (!M->global_empty()) {
     // Add an abbrev for common globals with no visibility or thread localness.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::MODULE_CODE_GLOBALVAR));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed,
                               Log2_32_Ceil(MaxGlobalType+1)));
@@ -655,7 +655,7 @@ static void WriteModuleInfo(const Module *M, const ValueEnumerator &VE,
       Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed,
                                Log2_32_Ceil(SectionMap.size()+1)));
     // Don't bother emitting vis + thread local.
-    SimpleGVarAbbrev = Stream.EmitAbbrev(Abbv);
+    SimpleGVarAbbrev = Stream.EmitAbbrev(Abbv.get());
   }
 
   // Emit the global variable information.
@@ -1172,11 +1172,11 @@ static void WriteModuleMetadata(const Module *M,
   unsigned MDSAbbrev = 0;
   if (VE.hasMDString()) {
     // Abbrev for METADATA_STRING.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::METADATA_STRING));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 8));
-    MDSAbbrev = Stream.EmitAbbrev(Abbv);
+    MDSAbbrev = Stream.EmitAbbrev(Abbv.get());
   }
 
   // Initialize MDNode abbreviations.
@@ -1188,14 +1188,14 @@ static void WriteModuleMetadata(const Module *M,
     //
     // Assume the column is usually under 128, and always output the inlined-at
     // location (it's never more expensive than building an array size 1).
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::METADATA_LOCATION));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 1));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 8));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6));
-    DILocationAbbrev = Stream.EmitAbbrev(Abbv);
+    DILocationAbbrev = Stream.EmitAbbrev(Abbv.get());
   }
 
   if (VE.hasGenericDINode()) {
@@ -1203,7 +1203,7 @@ static void WriteModuleMetadata(const Module *M,
     //
     // Assume the column is usually under 128, and always output the inlined-at
     // location (it's never more expensive than building an array size 1).
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::METADATA_GENERIC_DEBUG));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 1));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6));
@@ -1211,17 +1211,17 @@ static void WriteModuleMetadata(const Module *M,
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6));
-    GenericDINodeAbbrev = Stream.EmitAbbrev(Abbv);
+    GenericDINodeAbbrev = Stream.EmitAbbrev(Abbv.get());
   }
 
   unsigned NameAbbrev = 0;
   if (!M->named_metadata_empty()) {
     // Abbrev for METADATA_NAME.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::METADATA_NAME));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 8));
-    NameAbbrev = Stream.EmitAbbrev(Abbv);
+    NameAbbrev = Stream.EmitAbbrev(Abbv.get());
   }
 
   SmallVector<uint64_t, 64> Record;
@@ -1377,30 +1377,30 @@ static void WriteConstants(unsigned FirstVal, unsigned LastVal,
   // If this is a constant pool for the module, emit module-specific abbrevs.
   if (isGlobal) {
     // Abbrev for CST_CODE_AGGREGATE.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::CST_CODE_AGGREGATE));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, Log2_32_Ceil(LastVal+1)));
-    AggregateAbbrev = Stream.EmitAbbrev(Abbv);
+    AggregateAbbrev = Stream.EmitAbbrev(Abbv.get());
 
     // Abbrev for CST_CODE_STRING.
     Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::CST_CODE_STRING));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 8));
-    String8Abbrev = Stream.EmitAbbrev(Abbv);
+    String8Abbrev = Stream.EmitAbbrev(Abbv.get());
     // Abbrev for CST_CODE_CSTRING.
     Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::CST_CODE_CSTRING));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 7));
-    CString7Abbrev = Stream.EmitAbbrev(Abbv);
+    CString7Abbrev = Stream.EmitAbbrev(Abbv.get());
     // Abbrev for CST_CODE_CSTRING.
     Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::CST_CODE_CSTRING));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Char6));
-    CString6Abbrev = Stream.EmitAbbrev(Abbv);
+    CString6Abbrev = Stream.EmitAbbrev(Abbv.get());
   }
 
   SmallVector<uint64_t, 64> Record;
@@ -2181,70 +2181,70 @@ static void WriteBlockInfo(const ValueEnumerator &VE, BitstreamWriter &Stream) {
   Stream.EnterBlockInfoBlock(2);
 
   { // 8-bit fixed-width VST_ENTRY/VST_BBENTRY strings.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 3));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 8));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 8));
     if (Stream.EmitBlockInfoAbbrev(bitc::VALUE_SYMTAB_BLOCK_ID,
-                                   Abbv) != VST_ENTRY_8_ABBREV)
+                                   Abbv.get()) != VST_ENTRY_8_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
 
   { // 7-bit fixed width VST_ENTRY strings.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::VST_CODE_ENTRY));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 8));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 7));
     if (Stream.EmitBlockInfoAbbrev(bitc::VALUE_SYMTAB_BLOCK_ID,
-                                   Abbv) != VST_ENTRY_7_ABBREV)
+                                   Abbv.get()) != VST_ENTRY_7_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
   { // 6-bit char6 VST_ENTRY strings.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::VST_CODE_ENTRY));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 8));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Char6));
     if (Stream.EmitBlockInfoAbbrev(bitc::VALUE_SYMTAB_BLOCK_ID,
-                                   Abbv) != VST_ENTRY_6_ABBREV)
+                                   Abbv.get()) != VST_ENTRY_6_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
   { // 6-bit char6 VST_BBENTRY strings.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::VST_CODE_BBENTRY));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 8));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Char6));
     if (Stream.EmitBlockInfoAbbrev(bitc::VALUE_SYMTAB_BLOCK_ID,
-                                   Abbv) != VST_BBENTRY_6_ABBREV)
+                                   Abbv.get()) != VST_BBENTRY_6_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
 
 
 
   { // SETTYPE abbrev for CONSTANTS_BLOCK.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::CST_CODE_SETTYPE));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed,
                               VE.computeBitsRequiredForTypeIndicies()));
     if (Stream.EmitBlockInfoAbbrev(bitc::CONSTANTS_BLOCK_ID,
-                                   Abbv) != CONSTANTS_SETTYPE_ABBREV)
+                                   Abbv.get()) != CONSTANTS_SETTYPE_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
 
   { // INTEGER abbrev for CONSTANTS_BLOCK.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::CST_CODE_INTEGER));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 8));
     if (Stream.EmitBlockInfoAbbrev(bitc::CONSTANTS_BLOCK_ID,
-                                   Abbv) != CONSTANTS_INTEGER_ABBREV)
+                                   Abbv.get()) != CONSTANTS_INTEGER_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
 
   { // CE_CAST abbrev for CONSTANTS_BLOCK.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::CST_CODE_CE_CAST));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 4));  // cast opc
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed,       // typeid
@@ -2252,21 +2252,21 @@ static void WriteBlockInfo(const ValueEnumerator &VE, BitstreamWriter &Stream) {
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 8));    // value id
 
     if (Stream.EmitBlockInfoAbbrev(bitc::CONSTANTS_BLOCK_ID,
-                                   Abbv) != CONSTANTS_CE_CAST_Abbrev)
+                                   Abbv.get()) != CONSTANTS_CE_CAST_Abbrev)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
   { // NULL abbrev for CONSTANTS_BLOCK.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::CST_CODE_NULL));
     if (Stream.EmitBlockInfoAbbrev(bitc::CONSTANTS_BLOCK_ID,
-                                   Abbv) != CONSTANTS_NULL_Abbrev)
+                                   Abbv.get()) != CONSTANTS_NULL_Abbrev)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
 
   // FIXME: This should only use space for first class types!
 
   { // INST_LOAD abbrev for FUNCTION_BLOCK.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::FUNC_CODE_INST_LOAD));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6)); // Ptr
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed,    // dest ty
@@ -2274,73 +2274,73 @@ static void WriteBlockInfo(const ValueEnumerator &VE, BitstreamWriter &Stream) {
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 4)); // Align
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 1)); // volatile
     if (Stream.EmitBlockInfoAbbrev(bitc::FUNCTION_BLOCK_ID,
-                                   Abbv) != FUNCTION_INST_LOAD_ABBREV)
+                                   Abbv.get()) != FUNCTION_INST_LOAD_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
   { // INST_BINOP abbrev for FUNCTION_BLOCK.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::FUNC_CODE_INST_BINOP));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6)); // LHS
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6)); // RHS
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 4)); // opc
     if (Stream.EmitBlockInfoAbbrev(bitc::FUNCTION_BLOCK_ID,
-                                   Abbv) != FUNCTION_INST_BINOP_ABBREV)
+                                   Abbv.get()) != FUNCTION_INST_BINOP_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
   { // INST_BINOP_FLAGS abbrev for FUNCTION_BLOCK.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::FUNC_CODE_INST_BINOP));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6)); // LHS
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6)); // RHS
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 4)); // opc
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 7)); // flags
     if (Stream.EmitBlockInfoAbbrev(bitc::FUNCTION_BLOCK_ID,
-                                   Abbv) != FUNCTION_INST_BINOP_FLAGS_ABBREV)
+                                   Abbv.get()) != FUNCTION_INST_BINOP_FLAGS_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
   { // INST_CAST abbrev for FUNCTION_BLOCK.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::FUNC_CODE_INST_CAST));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6));    // OpVal
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed,       // dest ty
                               VE.computeBitsRequiredForTypeIndicies()));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 4));  // opc
     if (Stream.EmitBlockInfoAbbrev(bitc::FUNCTION_BLOCK_ID,
-                                   Abbv) != FUNCTION_INST_CAST_ABBREV)
+                                   Abbv.get()) != FUNCTION_INST_CAST_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
 
   { // INST_RET abbrev for FUNCTION_BLOCK.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::FUNC_CODE_INST_RET));
     if (Stream.EmitBlockInfoAbbrev(bitc::FUNCTION_BLOCK_ID,
-                                   Abbv) != FUNCTION_INST_RET_VOID_ABBREV)
+                                   Abbv.get()) != FUNCTION_INST_RET_VOID_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
   { // INST_RET abbrev for FUNCTION_BLOCK.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::FUNC_CODE_INST_RET));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6)); // ValID
     if (Stream.EmitBlockInfoAbbrev(bitc::FUNCTION_BLOCK_ID,
-                                   Abbv) != FUNCTION_INST_RET_VAL_ABBREV)
+                                   Abbv.get()) != FUNCTION_INST_RET_VAL_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
   { // INST_UNREACHABLE abbrev for FUNCTION_BLOCK.
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::FUNC_CODE_INST_UNREACHABLE));
     if (Stream.EmitBlockInfoAbbrev(bitc::FUNCTION_BLOCK_ID,
-                                   Abbv) != FUNCTION_INST_UNREACHABLE_ABBREV)
+                                   Abbv.get()) != FUNCTION_INST_UNREACHABLE_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
   {
-    BitCodeAbbrev *Abbv = new BitCodeAbbrev();
+    IntrusiveRefCntPtr<BitCodeAbbrev> Abbv = new BitCodeAbbrev();
     Abbv->Add(BitCodeAbbrevOp(bitc::FUNC_CODE_INST_GEP));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 1));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, // dest ty
                               Log2_32_Ceil(VE.getTypes().size() + 1)));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6));
-    if (Stream.EmitBlockInfoAbbrev(bitc::FUNCTION_BLOCK_ID, Abbv) !=
+    if (Stream.EmitBlockInfoAbbrev(bitc::FUNCTION_BLOCK_ID, Abbv.get()) !=
         FUNCTION_INST_GEP_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }

+ 1 - 0
lib/DxcSupport/CMakeLists.txt

@@ -2,6 +2,7 @@
 # This file is distributed under the University of Illinois Open Source License. See LICENSE.TXT for details.
 add_llvm_library(LLVMDxcSupport
   dxcapi.use.cpp
+  dxcmem.cpp
   FileIOHelper.cpp
   Global.cpp
   HLSLOptions.cpp

+ 205 - 83
lib/DxcSupport/FileIOHelper.cpp

@@ -5,7 +5,7 @@
 // This file is distributed under the University of Illinois Open Source     //
 // License. See LICENSE.TXT for details.                                     //
 //                                                                           //
-
+// TODO: consider including an empty blob singleton (possibly UTF-8/16 too). //
 //                                                                           //
 ///////////////////////////////////////////////////////////////////////////////
 
@@ -22,38 +22,109 @@
 
 #define CP_UTF16 1200
 
+struct HeapMalloc : public IMalloc {
+public:
+  ULONG STDMETHODCALLTYPE AddRef() {
+    return 1;
+  }
+  ULONG STDMETHODCALLTYPE Release() {
+    return 1;
+  }
+  STDMETHODIMP QueryInterface(REFIID iid, void** ppvObject) {
+    return DoBasicQueryInterface<IMalloc>(this, iid, ppvObject);
+  }
+  virtual void *STDMETHODCALLTYPE Alloc(
+    /* [annotation][in] */
+    _In_  SIZE_T cb) {
+    return HeapAlloc(GetProcessHeap(), 0, cb);
+  }
+
+  virtual void *STDMETHODCALLTYPE Realloc(
+    /* [annotation][in] */
+    _In_opt_  void *pv,
+    /* [annotation][in] */
+    _In_  SIZE_T cb)
+  {
+    return HeapReAlloc(GetProcessHeap(), 0, pv, cb);
+  }
+
+  virtual void STDMETHODCALLTYPE Free(
+    /* [annotation][in] */
+    _In_opt_  void *pv)
+  {
+    HeapFree(GetProcessHeap(), 0, pv);
+  }
+
+
+  virtual SIZE_T STDMETHODCALLTYPE GetSize(
+    /* [annotation][in] */
+    _In_opt_ _Post_writable_byte_size_(return)  void *pv)
+  {
+    return HeapSize(GetProcessHeap(), 0, pv);
+  }
+
+  virtual int STDMETHODCALLTYPE DidAlloc(
+    /* [annotation][in] */
+    _In_opt_  void *pv)
+  {
+    return -1; // don't know
+  }
+
+
+  virtual void STDMETHODCALLTYPE HeapMinimize(void)
+  {
+  }
+};
+
+static HeapMalloc g_HeapMalloc;
+
 namespace hlsl {
 
+IMalloc *GetGlobalHeapMalloc() {
+  return &g_HeapMalloc;
+}
+
 _Use_decl_annotations_
-void ReadBinaryFile(LPCWSTR pFileName, void **ppData, DWORD *pDataSize) {
-  HANDLE hFile = CreateFileW(pFileName, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr);
-  if(hFile == INVALID_HANDLE_VALUE) {
+void ReadBinaryFile(IMalloc *pMalloc, LPCWSTR pFileName, void **ppData,
+                    DWORD *pDataSize) {
+  HANDLE hFile = CreateFileW(pFileName, GENERIC_READ, FILE_SHARE_READ, NULL,
+                             OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr);
+  if (hFile == INVALID_HANDLE_VALUE) {
     IFT(HRESULT_FROM_WIN32(GetLastError()));
   }
+
   CHandle h(hFile);
 
   LARGE_INTEGER FileSize;
-  if(!GetFileSizeEx(hFile, &FileSize)) {
+  if (!GetFileSizeEx(hFile, &FileSize)) {
     IFT(HRESULT_FROM_WIN32(GetLastError()));
   }
-  if(FileSize.HighPart != 0) {
+  if (FileSize.HighPart != 0) {
     throw(hlsl::Exception(DXC_E_INPUT_FILE_TOO_LARGE, "input file is too large"));
   }
-  CComHeapPtr<char> pData;
-  if (!pData.AllocateBytes(FileSize.LowPart)) {
+
+  char *pData = (char *)pMalloc->Alloc(FileSize.LowPart);
+  if (!pData) {
     throw std::bad_alloc();
   }
 
   DWORD BytesRead;
-  if(!ReadFile(hFile, pData.m_pData, FileSize.LowPart, &BytesRead, nullptr)) {
-    IFT(HRESULT_FROM_WIN32(GetLastError()));
+  if (!ReadFile(hFile, pData, FileSize.LowPart, &BytesRead, nullptr)) {
+    HRESULT hr = HRESULT_FROM_WIN32(GetLastError());
+    pMalloc->Free(pData);
+    throw ::hlsl::Exception(hr);
   }
   DXASSERT(FileSize.LowPart == BytesRead, "ReadFile operation failed");
 
-  *ppData = pData.Detach();
+  *ppData = pData;
   *pDataSize = FileSize.LowPart;
 }
 
+_Use_decl_annotations_
+void ReadBinaryFile(LPCWSTR pFileName, void **ppData, DWORD *pDataSize) {
+  return ReadBinaryFile(GetGlobalHeapMalloc(), pFileName, ppData, pDataSize);
+}
+
 _Use_decl_annotations_
 void WriteBinaryFile(LPCWSTR pFileName, const void *pData, DWORD DataSize) {
   HANDLE hFile = CreateFileW(pFileName, GENERIC_WRITE, 0, NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, nullptr);
@@ -103,18 +174,26 @@ UINT32 DxcCodePageFromBytes(const char *bytes, size_t byteLen) {
 
 class InternalDxcBlobEncoding : public IDxcBlobEncoding {
 private:
-  DXC_MICROCOM_REF_FIELD(m_dwRef)
+  DXC_MICROCOM_TM_REF_FIELDS() // an underlying m_pMalloc that owns this
   LPCVOID m_Buffer = nullptr;
-  IUnknown* m_Owner = nullptr; // IMalloc when MallocFree is true
+  IUnknown* m_Owner = nullptr; // IMalloc when MallocFree is true, owning the buffer
   SIZE_T m_BufferSize;
-  unsigned m_HeapFree : 1;
   unsigned m_EncodingKnown : 1;
   unsigned m_MallocFree : 1;
   UINT32 m_CodePage;
 public:
-  DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
-  InternalDxcBlobEncoding() : m_dwRef(0) {
+  DXC_MICROCOM_ADDREF_IMPL(m_dwRef)
+  ULONG STDMETHODCALLTYPE Release() {
+    // Because blobs are also used by tests and utilities, we avoid using TLS.
+    ULONG result = InterlockedDecrement(&m_dwRef);
+    if (result == 0) {
+      CComPtr<IMalloc> pTmp(m_pMalloc);
+      this->~InternalDxcBlobEncoding();
+      pTmp->Free(this);
+    }
+    return result;
   }
+  DXC_MICROCOM_TM_CTOR(InternalDxcBlobEncoding)
   HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
     return DoBasicQueryInterface<IDxcBlob, IDxcBlobEncoding>(this, iid, ppvObject);
   }
@@ -126,33 +205,20 @@ public:
     if (m_Owner != nullptr) {
       m_Owner->Release();
     }
-    if (m_HeapFree) {
-      CoTaskMemFree((LPVOID)m_Buffer);
-    }
   }
 
   static HRESULT
   CreateFromHeap(LPCVOID buffer, SIZE_T bufferSize, bool encodingKnown,
                  UINT32 codePage,
-                 _COM_Outptr_ InternalDxcBlobEncoding **pEncoding) {
-    *pEncoding = new (std::nothrow) InternalDxcBlobEncoding();
-    if (*pEncoding == nullptr) {
-      return E_OUTOFMEMORY;
-    }
-    (*pEncoding)->m_Buffer = buffer;
-    (*pEncoding)->m_BufferSize = bufferSize;
-    (*pEncoding)->m_HeapFree = 1;
-    (*pEncoding)->m_EncodingKnown = encodingKnown;
-    (*pEncoding)->m_MallocFree = 0;
-    (*pEncoding)->m_CodePage = codePage;
-    (*pEncoding)->AddRef();
-    return S_OK;
+                 _COM_Outptr_ InternalDxcBlobEncoding **ppEncoding) {
+    return CreateFromMalloc(buffer, DxcGetThreadMallocNoRef(), bufferSize,
+                            encodingKnown, codePage, ppEncoding);
   }
 
   static HRESULT
-  CreateFromBlob(_In_ IDxcBlob *pBlob, bool encodingKnown, UINT32 codePage,
+  CreateFromBlob(_In_ IDxcBlob *pBlob, _In_ IMalloc *pMalloc, bool encodingKnown, UINT32 codePage,
                  _COM_Outptr_ InternalDxcBlobEncoding **pEncoding) {
-    *pEncoding = new (std::nothrow) InternalDxcBlobEncoding();
+    *pEncoding = InternalDxcBlobEncoding::Alloc(pMalloc);
     if (*pEncoding == nullptr) {
       return E_OUTOFMEMORY;
     }
@@ -160,25 +226,25 @@ public:
     (*pEncoding)->m_Owner = pBlob;
     (*pEncoding)->m_Buffer = pBlob->GetBufferPointer();
     (*pEncoding)->m_BufferSize = pBlob->GetBufferSize();
-    (*pEncoding)->m_HeapFree = 0;
     (*pEncoding)->m_EncodingKnown = encodingKnown;
     (*pEncoding)->m_MallocFree = 0;
     (*pEncoding)->m_CodePage = codePage;
     (*pEncoding)->AddRef();
     return S_OK;
   }
+
   static HRESULT
   CreateFromMalloc(LPCVOID buffer, IMalloc *pIMalloc, SIZE_T bufferSize, bool encodingKnown,
     UINT32 codePage, _COM_Outptr_ InternalDxcBlobEncoding **pEncoding) {
-    *pEncoding = new (std::nothrow) InternalDxcBlobEncoding();
+    *pEncoding = InternalDxcBlobEncoding::Alloc(pIMalloc);
     if (*pEncoding == nullptr) {
+      *pEncoding = nullptr;
       return E_OUTOFMEMORY;
     }
     pIMalloc->AddRef();
     (*pEncoding)->m_Owner = pIMalloc;
     (*pEncoding)->m_Buffer = buffer;
     (*pEncoding)->m_BufferSize = bufferSize;
-    (*pEncoding)->m_HeapFree = 0;
     (*pEncoding)->m_EncodingKnown = encodingKnown;
     (*pEncoding)->m_MallocFree = 1;
     (*pEncoding)->m_CodePage = codePage;
@@ -207,12 +273,12 @@ public:
 
   // Relatively dangerous API. This means the buffer should be pinned for as
   // long as this object is alive.
-  void ClearFreeFlag() { m_HeapFree = 0; }
+  void ClearFreeFlag() { m_MallocFree = 0; }
 };
 
 static HRESULT CodePageBufferToUtf16(UINT32 codePage, LPCVOID bufferPointer,
                                      SIZE_T bufferSize,
-                                     CComHeapPtr<WCHAR> &utf16NewCopy,
+                                     CDxcMallocHeapPtr<WCHAR> &utf16NewCopy,
                                      _Out_ UINT32 *pConvertedCharCount) {
   *pConvertedCharCount = 0;
 
@@ -272,7 +338,7 @@ HRESULT DxcCreateBlobFromBlob(
     IFR(pBlobEncoding->GetEncoding(&encodingKnown, &codePage));
   }
   CComPtr<InternalDxcBlobEncoding> pCreated;
-  IFR(InternalDxcBlobEncoding::CreateFromBlob(pBlob, encodingKnown, codePage,
+  IFR(InternalDxcBlobEncoding::CreateFromBlob(pBlob, DxcGetThreadMallocNoRef(), encodingKnown, codePage,
                                               &pCreated));
   pCreated->AdjustPtrAndSize(offset, length);
   *ppResult = pCreated.Detach();
@@ -317,17 +383,18 @@ DxcCreateBlobOnHeapCopy(_In_bytecount_(size) LPCVOID pData, UINT32 size,
 }
 
 _Use_decl_annotations_
-HRESULT DxcCreateBlobFromFile(LPCWSTR pFileName, UINT32 *pCodePage,
-                              IDxcBlobEncoding **ppBlobEncoding) {
+HRESULT
+DxcCreateBlobFromFile(IMalloc *pMalloc, LPCWSTR pFileName, UINT32 *pCodePage,
+                      IDxcBlobEncoding **ppBlobEncoding) throw() {
   if (pFileName == nullptr || ppBlobEncoding == nullptr) {
     return E_POINTER;
   }
 
-  CComHeapPtr<char> pData;
+  LPVOID pData;
   DWORD dataSize;
   *ppBlobEncoding = nullptr;
   try {
-    ReadBinaryFile(pFileName, (void **)(&pData), &dataSize);
+    ReadBinaryFile(pMalloc, pFileName, &pData, &dataSize);
   }
   CATCH_CPP_RETURN_HRESULT();
 
@@ -335,30 +402,51 @@ HRESULT DxcCreateBlobFromFile(LPCWSTR pFileName, UINT32 *pCodePage,
   UINT32 codePage = (pCodePage != nullptr) ? *pCodePage : 0;
 
   InternalDxcBlobEncoding *internalEncoding;
-  HRESULT hr = InternalDxcBlobEncoding::CreateFromHeap(
-      pData, dataSize, known, codePage, &internalEncoding);
+  HRESULT hr = InternalDxcBlobEncoding::CreateFromMalloc(
+    pData, pMalloc, dataSize, known, codePage, &internalEncoding);
   if (SUCCEEDED(hr)) {
     *ppBlobEncoding = internalEncoding;
-    pData.Detach();
+  }
+  else {
+    pMalloc->Free(pData);
   }
   return hr;
 }
 
+_Use_decl_annotations_
+HRESULT DxcCreateBlobFromFile(LPCWSTR pFileName, UINT32 *pCodePage,
+                              IDxcBlobEncoding **ppBlobEncoding) {
+  CComPtr<IMalloc> pMalloc;
+  IFR(CoGetMalloc(1, &pMalloc));
+  return DxcCreateBlobFromFile(pMalloc, pFileName, pCodePage, ppBlobEncoding);
+}
+
 _Use_decl_annotations_
 HRESULT
-DxcCreateBlobWithEncodingSet(IDxcBlob *pBlob, UINT32 codePage,
-                             IDxcBlobEncoding **pBlobEncoding) {
-  *pBlobEncoding = nullptr;
+DxcCreateBlobWithEncodingSet(IMalloc *pMalloc, IDxcBlob *pBlob, UINT32 codePage,
+                             IDxcBlobEncoding **ppBlobEncoding) {
+  DXASSERT_NOMSG(pMalloc != nullptr);
+  DXASSERT_NOMSG(pBlob != nullptr);
+  DXASSERT_NOMSG(ppBlobEncoding != nullptr);
+  *ppBlobEncoding = nullptr;
 
   InternalDxcBlobEncoding *internalEncoding;
-  HRESULT hr = InternalDxcBlobEncoding::CreateFromBlob(pBlob, true, codePage,
-                                                       &internalEncoding);
+  HRESULT hr = InternalDxcBlobEncoding::CreateFromBlob(
+      pBlob, pMalloc, true, codePage, &internalEncoding);
   if (SUCCEEDED(hr)) {
-    *pBlobEncoding = internalEncoding;
+    *ppBlobEncoding = internalEncoding;
   }
   return hr;
 }
 
+_Use_decl_annotations_
+HRESULT
+DxcCreateBlobWithEncodingSet(IDxcBlob *pBlob, UINT32 codePage,
+                             IDxcBlobEncoding **ppBlobEncoding) {
+  return DxcCreateBlobWithEncodingSet(DxcGetThreadMallocNoRef(), pBlob,
+                                      codePage, ppBlobEncoding);
+}
+
 _Use_decl_annotations_
 HRESULT DxcCreateBlobWithEncodingFromPinned(LPCVOID pText, UINT32 size,
                                             UINT32 codePage,
@@ -426,8 +514,8 @@ DxcCreateBlobWithEncodingOnHeapCopy(LPCVOID pText, UINT32 size, UINT32 codePage,
   IDxcBlobEncoding **pBlobEncoding) {
   *pBlobEncoding = nullptr;
 
-  CComHeapPtr<char> heapCopy;
-  if (!heapCopy.AllocateBytes(size)) {
+  CDxcMallocHeapPtr<char> heapCopy(DxcGetThreadMallocNoRef());
+  if (!heapCopy.Allocate(size)) {
     return E_OUTOFMEMORY;
   }
   memcpy(heapCopy.m_pData, pText, size);
@@ -455,6 +543,23 @@ DxcCreateBlobWithEncodingOnMalloc(LPCVOID pText, IMalloc *pIMalloc, UINT32 size,
   return hr;
 }
 
+_Use_decl_annotations_
+HRESULT
+DxcCreateBlobWithEncodingOnMallocCopy(IMalloc *pIMalloc, LPCVOID pText, UINT32 size, UINT32 codePage,
+  IDxcBlobEncoding **ppBlobEncoding) {
+  *ppBlobEncoding = nullptr;
+  void *pData = pIMalloc->Alloc(size);
+  if (pData == nullptr)
+    return E_OUTOFMEMORY;
+  memcpy(pData, pText, size);
+  HRESULT hr = DxcCreateBlobWithEncodingOnMalloc(pData, pIMalloc, size, codePage, ppBlobEncoding);
+  if (FAILED(hr)) {
+    pIMalloc->Free(pData);
+    return hr;
+  }
+  return S_OK;
+}
+
 
 _Use_decl_annotations_
 HRESULT DxcGetBlobAsUtf8(IDxcBlob *pBlob, IDxcBlobEncoding **pBlobEncoding) {
@@ -488,7 +593,7 @@ HRESULT DxcGetBlobAsUtf8(IDxcBlob *pBlob, IDxcBlobEncoding **pBlobEncoding) {
   if (codePage == CP_UTF8) {
     // Reuse the underlying blob but create an object with the encoding known.
     InternalDxcBlobEncoding* internalEncoding;
-    hr = InternalDxcBlobEncoding::CreateFromBlob(pBlob, true, CP_UTF8, &internalEncoding);
+    hr = InternalDxcBlobEncoding::CreateFromBlob(pBlob, DxcGetThreadMallocNoRef(), true, CP_UTF8, &internalEncoding);
     if (SUCCEEDED(hr)) {
       *pBlobEncoding = internalEncoding;
     }
@@ -499,7 +604,7 @@ HRESULT DxcGetBlobAsUtf8(IDxcBlob *pBlob, IDxcBlobEncoding **pBlobEncoding) {
 
   // Any UTF-16 output must be converted to UTF-16 first, then
   // back to the target code page.
-  CComHeapPtr<WCHAR> utf16NewCopy;
+  CDxcMallocHeapPtr<WCHAR> utf16NewCopy(DxcGetThreadMallocNoRef());
   wchar_t* utf16Chars = nullptr;
   UINT32 utf16CharCount;
   if (codePage == CP_UTF16) {
@@ -516,7 +621,7 @@ HRESULT DxcGetBlobAsUtf8(IDxcBlob *pBlob, IDxcBlobEncoding **pBlobEncoding) {
   }
 
   const UINT32 targetCodePage = CP_UTF8;
-  CComHeapPtr<char> finalNewCopy;
+  CDxcTMHeapPtr<char> finalNewCopy;
   int numToConvertFinal = WideCharToMultiByte(
     targetCodePage, 0, utf16Chars, utf16CharCount,
     finalNewCopy, 0, NULL, NULL);
@@ -537,7 +642,8 @@ HRESULT DxcGetBlobAsUtf8(IDxcBlob *pBlob, IDxcBlobEncoding **pBlobEncoding) {
   ((LPSTR)finalNewCopy)[numActuallyConvertedFinal] = '\0';
 
   InternalDxcBlobEncoding* internalEncoding;
-  hr = InternalDxcBlobEncoding::CreateFromHeap(finalNewCopy.m_pData,
+  hr = InternalDxcBlobEncoding::CreateFromMalloc(finalNewCopy.m_pData,
+    DxcGetThreadMallocNoRef(),
     numActuallyConvertedFinal, true, targetCodePage, &internalEncoding);
   if (SUCCEEDED(hr)) {
     *pBlobEncoding = internalEncoding;
@@ -573,13 +679,14 @@ DxcGetBlobAsUtf8NullTerm(_In_ IDxcBlob *pBlob,
       }
       
       // We have a non-null-terminated UTF-8 stream. Copy to a new location.
-      CComHeapPtr<char> pCopy;
+      CDxcTMHeapPtr<char> pCopy;
       if (!pCopy.Allocate(blobSize + 1))
         return E_OUTOFMEMORY;
       memcpy(pCopy.m_pData, pChars, blobSize);
       pCopy.m_pData[blobSize] = '\0';
-      IFR(DxcCreateBlobWithEncodingOnHeap(pCopy.m_pData, blobSize + 1, CP_UTF8,
-                                          ppBlobEncoding));
+      IFR(DxcCreateBlobWithEncodingOnMalloc(
+          pCopy.m_pData, DxcGetThreadMallocNoRef(), blobSize + 1, CP_UTF8,
+          ppBlobEncoding));
       pCopy.Detach();
       return S_OK;
     }
@@ -593,7 +700,7 @@ DxcGetBlobAsUtf8NullTerm(_In_ IDxcBlob *pBlob,
 }
 
 _Use_decl_annotations_
-HRESULT DxcGetBlobAsUtf16(IDxcBlob *pBlob, IDxcBlobEncoding **pBlobEncoding) {
+HRESULT DxcGetBlobAsUtf16(IDxcBlob *pBlob, IMalloc *pMalloc, IDxcBlobEncoding **pBlobEncoding) {
   *pBlobEncoding = nullptr;
 
   HRESULT hr;
@@ -624,7 +731,7 @@ HRESULT DxcGetBlobAsUtf16(IDxcBlob *pBlob, IDxcBlobEncoding **pBlobEncoding) {
   // Reuse the underlying blob but create an object with the encoding known.
   if (codePage == CP_UTF16) {
     InternalDxcBlobEncoding* internalEncoding;
-    hr = InternalDxcBlobEncoding::CreateFromBlob(pBlob, true, CP_UTF16, &internalEncoding);
+    hr = InternalDxcBlobEncoding::CreateFromBlob(pBlob, pMalloc, true, CP_UTF16, &internalEncoding);
     if (SUCCEEDED(hr)) {
       *pBlobEncoding = internalEncoding;
     }
@@ -632,7 +739,7 @@ HRESULT DxcGetBlobAsUtf16(IDxcBlob *pBlob, IDxcBlobEncoding **pBlobEncoding) {
   }
 
   // Convert and create a blob that owns the encoding.
-  CComHeapPtr<WCHAR> utf16NewCopy;
+  CDxcMallocHeapPtr<WCHAR> utf16NewCopy(pMalloc);
   UINT32 utf16CharCount;
   hr = CodePageBufferToUtf16(codePage, pBlob->GetBufferPointer(), blobLen,
                              utf16NewCopy, &utf16CharCount);
@@ -641,8 +748,9 @@ HRESULT DxcGetBlobAsUtf16(IDxcBlob *pBlob, IDxcBlobEncoding **pBlobEncoding) {
   }
 
   InternalDxcBlobEncoding* internalEncoding;
-  hr = InternalDxcBlobEncoding::CreateFromHeap(utf16NewCopy.m_pData,
-    utf16CharCount * sizeof(WCHAR), true, CP_UTF16, &internalEncoding);
+  hr = InternalDxcBlobEncoding::CreateFromMalloc(
+      utf16NewCopy.m_pData, pMalloc,
+      utf16CharCount * sizeof(WCHAR), true, CP_UTF16, &internalEncoding);
   if (SUCCEEDED(hr)) {
     *pBlobEncoding = internalEncoding;
     utf16NewCopy.Detach();
@@ -659,23 +767,31 @@ bool IsBlobNullOrEmpty(_In_opt_ IDxcBlob *pBlob) throw() {
 
 class MemoryStream : public AbstractMemoryStream, public IDxcBlob {
 private:
-  DXC_MICROCOM_REF_FIELD(m_dwRef)
-  CComPtr<IMalloc> m_pMalloc;
-  LPBYTE m_pMemory;
-  ULONG m_offset;
-  ULONG m_size;
-  ULONG m_allocSize;
+  DXC_MICROCOM_TM_REF_FIELDS()
+  LPBYTE m_pMemory = nullptr;
+  ULONG m_offset = 0;
+  ULONG m_size = 0;
+  ULONG m_allocSize = 0;
 public:
-  DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
+  DXC_MICROCOM_ADDREF_IMPL(m_dwRef)
+  ULONG STDMETHODCALLTYPE Release() {
+    // Because memory streams are also used by tests and utilities,
+    // we avoid using TLS.
+    ULONG result = InterlockedDecrement(&m_dwRef); \
+    if (result == 0) {
+      CComPtr<IMalloc> pTmp(m_pMalloc);
+      this->~MemoryStream();
+      pTmp->Free(this);
+    }
+    return result;
+  }
+
+  DXC_MICROCOM_TM_CTOR(MemoryStream)
 
   HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
     return DoBasicQueryInterface<IStream, ISequentialStream, IDxcBlob>(this, iid, ppvObject);
   }
 
-  MemoryStream(_In_ IMalloc *pMalloc)
-    : m_dwRef(0), m_pMalloc(pMalloc), m_pMemory(nullptr), m_offset(0),
-    m_size(0), m_allocSize(0) {}
-
   ~MemoryStream() {
     Reset();
   }
@@ -866,19 +982,22 @@ public:
 
 class ReadOnlyBlobStream : public IStream {
 private:
-  DXC_MICROCOM_REF_FIELD(m_dwRef)
+  DXC_MICROCOM_TM_REF_FIELDS()
   CComPtr<IDxcBlob> m_pSource;
   LPBYTE m_pMemory;
   ULONG m_offset;
   ULONG m_size;
 public:
-  DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
+  DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
+  DXC_MICROCOM_TM_CTOR(ReadOnlyBlobStream)
 
   HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
     return DoBasicQueryInterface<IStream, ISequentialStream>(this, iid, ppvObject);
   }
 
-  ReadOnlyBlobStream(IDxcBlob *pSource) : m_pSource(pSource), m_offset(0), m_dwRef(0) {
+  void Init(IDxcBlob *pSource) {
+    m_pSource = pSource;
+    m_offset = 0;
     m_size = m_pSource->GetBufferSize();
     m_pMemory = (LPBYTE)m_pSource->GetBufferPointer();
   }
@@ -982,7 +1101,7 @@ HRESULT CreateMemoryStream(_In_ IMalloc *pMalloc, _COM_Outptr_ AbstractMemoryStr
     return E_POINTER;
   }
 
-  CComPtr<MemoryStream> stream = new (std::nothrow) MemoryStream(pMalloc);
+  CComPtr<MemoryStream> stream = MemoryStream::Alloc(pMalloc);
   *ppResult = stream.Detach();
   return (*ppResult == nullptr) ? E_OUTOFMEMORY : S_OK;
 }
@@ -992,7 +1111,10 @@ HRESULT CreateReadOnlyBlobStream(_In_ IDxcBlob *pSource, _COM_Outptr_ IStream**
     return E_POINTER;
   }
 
-  CComPtr<ReadOnlyBlobStream> stream = new (std::nothrow) ReadOnlyBlobStream(pSource);
+  CComPtr<ReadOnlyBlobStream> stream = ReadOnlyBlobStream::Alloc(DxcGetThreadMallocNoRef());
+  if (stream.p) {
+    stream->Init(pSource);
+  }
   *ppResult = stream.Detach();
   return (*ppResult == nullptr) ? E_OUTOFMEMORY : S_OK;
 }

+ 15 - 2
lib/DxcSupport/HLSLOptions.cpp

@@ -46,10 +46,23 @@ namespace {
 
 }
 
-static HlslOptTable g_HlslOptTable;
+static HlslOptTable *g_HlslOptTable;
+
+std::error_code hlsl::options::initHlslOptTable() {
+  DXASSERT(g_HlslOptTable == nullptr, "else double-init");
+  g_HlslOptTable = new (std::nothrow) HlslOptTable();
+  if (g_HlslOptTable == nullptr)
+    return std::error_code(E_OUTOFMEMORY, std::system_category());
+  return std::error_code();
+}
+
+void hlsl::options::cleanupHlslOptTable() {
+  delete g_HlslOptTable;
+  g_HlslOptTable = nullptr;
+}
 
 const OptTable * hlsl::options::getHlslOptTable() {
-  return &g_HlslOptTable;
+  return g_HlslOptTable;
 }
 
 void DxcDefines::push_back(llvm::StringRef value) {

+ 94 - 0
lib/DxcSupport/dxcmem.cpp

@@ -0,0 +1,94 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// dxcmem.cpp                                                                //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+// Provides support for a thread-local allocator.                            //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "dxc/Support/Global.h"
+#include <specstrings.h>
+
+#include "dxc/Support/WinIncludes.h"
+#include <memory>
+
+static DWORD g_ThreadMallocTlsIndex;
+static IMalloc *g_pDefaultMalloc;
+
+// Used by DllMain to set up and tear down per-thread tracking.
+HRESULT DxcInitThreadMalloc() throw();
+void DxcCleanupThreadMalloc() throw();
+
+// Used by APIs that are entry points to set up per-thread/invocation allocator.
+void DxcSetThreadMalloc(IMalloc *pMalloc) throw();
+void DxcSetThreadMallocOrDefault(IMalloc *pMalloc) throw(); 
+void DxcClearThreadMalloc() throw();
+
+// Used to retrieve the current invocation's allocator or perform an alloc/free/realloc.
+IMalloc *DxcGetThreadMallocNoRef() throw();
+_Ret_maybenull_ _Post_writable_byte_size_(nBytes) void *DxcThreadAlloc(size_t nBytes) throw();
+void DxcThreadFree(void *) throw();
+
+HRESULT DxcInitThreadMalloc() {
+  DXASSERT(g_ThreadMallocTlsIndex == 0, "else InitThreadMalloc already called");
+  DXASSERT(g_pDefaultMalloc == nullptr, "else InitThreadMalloc already called");
+
+  // We capture the default malloc early to avoid potential failures later on.
+  HRESULT hrMalloc = CoGetMalloc(1, &g_pDefaultMalloc);
+  if (FAILED(hrMalloc)) return hrMalloc;
+
+  g_ThreadMallocTlsIndex = TlsAlloc();
+  if (g_ThreadMallocTlsIndex == TLS_OUT_OF_INDEXES) {
+    g_ThreadMallocTlsIndex = 0;
+    g_pDefaultMalloc->Release();
+    g_pDefaultMalloc = nullptr;
+    return E_OUTOFMEMORY;
+  }
+
+  return S_OK;
+}
+
+void DxcCleanupThreadMalloc() {
+  if (g_ThreadMallocTlsIndex) {
+    TlsFree(g_ThreadMallocTlsIndex);
+    g_ThreadMallocTlsIndex = 0;
+    DXASSERT(g_pDefaultMalloc, "else DxcInitThreadMalloc didn't work/fail atomically");
+    g_pDefaultMalloc->Release();
+    g_pDefaultMalloc = nullptr;
+  }
+}
+
+IMalloc *DxcGetThreadMallocNoRef() {
+  DXASSERT(g_ThreadMallocTlsIndex != 0, "else prior to DxcInitThreadMalloc or after DxcCleanupThreadMalloc");
+  return reinterpret_cast<IMalloc *>(TlsGetValue(g_ThreadMallocTlsIndex));
+}
+void DxcClearThreadMalloc() {
+  DXASSERT(g_ThreadMallocTlsIndex != 0, "else prior to DxcInitThreadMalloc or after DxcCleanupThreadMalloc");
+  IMalloc *pMalloc = DxcGetThreadMallocNoRef();
+  DXVERIFY_NOMSG(TlsSetValue(g_ThreadMallocTlsIndex, nullptr));
+  pMalloc->Release();
+}
+void DxcSetThreadMalloc(IMalloc *pMalloc) {
+  DXASSERT(g_ThreadMallocTlsIndex != 0, "else prior to DxcInitThreadMalloc or after DxcCleanupThreadMalloc");
+  DXASSERT(DxcGetThreadMallocNoRef() == nullptr, "else nested allocation invoked");
+  DXVERIFY_NOMSG(TlsSetValue(g_ThreadMallocTlsIndex, pMalloc));
+  pMalloc->AddRef();
+}
+void DxcSetThreadMallocOrDefault(IMalloc *pMalloc) {
+  DxcSetThreadMalloc(pMalloc ? pMalloc : g_pDefaultMalloc);
+}
+IMalloc *DxcSwapThreadMalloc(IMalloc *pMalloc, IMalloc **ppPrior) {
+  DXASSERT(g_ThreadMallocTlsIndex != 0, "else prior to DxcInitThreadMalloc or after DxcCleanupThreadMalloc");
+  IMalloc *pPrior = DxcGetThreadMallocNoRef();
+  if (ppPrior) {
+    *ppPrior = pPrior;
+  }
+  DXVERIFY_NOMSG(TlsSetValue(g_ThreadMallocTlsIndex, pMalloc));
+  return pMalloc;
+}
+IMalloc *DxcSwapThreadMallocOrDefault(IMalloc *pMallocOrNull, IMalloc **ppPrior) {
+  return DxcSwapThreadMalloc(pMallocOrNull ? pMallocOrNull : g_pDefaultMalloc, ppPrior);
+}

+ 15 - 14
lib/HLSL/DxcOptimizer.cpp

@@ -383,19 +383,19 @@ static HRESULT Utf8ToUtf16CoTaskMalloc(LPCSTR pValue, LPWSTR *ppResult) {
 
 class DxcOptimizerPass : public IDxcOptimizerPass {
 private:
-  DXC_MICROCOM_REF_FIELD(m_dwRef)
+  DXC_MICROCOM_TM_REF_FIELDS()
   LPCSTR m_pOptionName;
   LPCSTR m_pDescription;
   ArrayRef<LPCSTR> m_pArgNames;
   ArrayRef<LPCSTR> m_pArgDescriptions;
 public:
-  DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
+  DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
+  DXC_MICROCOM_TM_CTOR(DxcOptimizerPass)
 
   HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
     return DoBasicQueryInterface<IDxcOptimizerPass>(this, iid, ppvObject);
   }
 
-  DxcOptimizerPass() : m_dwRef(0) { }
   HRESULT Initialize(LPCSTR pOptionName, LPCSTR pDescription, ArrayRef<LPCSTR> pArgNames, ArrayRef<LPCSTR> pArgDescriptions) {
     DXASSERT(pArgNames.size() == pArgDescriptions.size(), "else lookup tables are out of alignment");
     m_pOptionName = pOptionName;
@@ -404,10 +404,10 @@ public:
     m_pArgDescriptions = pArgDescriptions;
     return S_OK;
   }
-  static HRESULT Create(LPCSTR pOptionName, LPCSTR pDescription, ArrayRef<LPCSTR> pArgNames, ArrayRef<LPCSTR> pArgDescriptions, IDxcOptimizerPass **ppResult) {
+  static HRESULT Create(IMalloc *pMalloc, LPCSTR pOptionName, LPCSTR pDescription, ArrayRef<LPCSTR> pArgNames, ArrayRef<LPCSTR> pArgDescriptions, IDxcOptimizerPass **ppResult) {
     CComPtr<DxcOptimizerPass> result;
     *ppResult = nullptr;
-    result = new (std::nothrow)DxcOptimizerPass();
+    result = DxcOptimizerPass::Alloc(pMalloc);
     IFROOM(result);
     IFR(result->Initialize(pOptionName, pDescription, pArgNames, pArgDescriptions));
     *ppResult = result.Detach();
@@ -441,17 +441,17 @@ public:
 
 class DxcOptimizer : public IDxcOptimizer {
 private:
-  DXC_MICROCOM_REF_FIELD(m_dwRef)
+  DXC_MICROCOM_TM_REF_FIELDS()
   PassRegistry *m_registry;
   std::vector<const PassInfo *> m_passes;
 public:
-  DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
+  DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
+  DXC_MICROCOM_TM_CTOR(DxcOptimizer)
 
   HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
     return DoBasicQueryInterface<IDxcOptimizer>(this, iid, ppvObject);
   }
 
-  DxcOptimizer() : m_dwRef(0) { }
   HRESULT Initialize();
   const PassInfo *getPassByID(llvm::AnalysisID PassID);
   const PassInfo *getPassByName(const char *pName);
@@ -519,7 +519,8 @@ HRESULT STDMETHODCALLTYPE DxcOptimizer::GetAvailablePass(
   if (index >= m_passes.size())
     return E_INVALIDARG;
   return DxcOptimizerPass::Create(
-      m_passes[index]->getPassArgument(), m_passes[index]->getPassName(),
+      m_pMalloc, m_passes[index]->getPassArgument(),
+      m_passes[index]->getPassName(),
       GetPassArgNames(m_passes[index]->getPassArgument()),
       GetPassArgDescriptions(m_passes[index]->getPassArgument()), ppResult);
 }
@@ -535,6 +536,8 @@ HRESULT STDMETHODCALLTYPE DxcOptimizer::RunOptimizer(
   if (optionCount > 0 && ppOptions == nullptr)
     return E_POINTER;
 
+  DxcThreadMalloc TM(m_pMalloc);
+
   // Setup input buffer.
   // The ir parsing requires the buffer to be null terminated. We deal with
   // both source and bitcode input, so the input buffer may not be null
@@ -559,12 +562,10 @@ HRESULT STDMETHODCALLTYPE DxcOptimizer::RunOptimizer(
   legacy::PassManagerBase *pPassManager = &ModulePasses;
 
   try {
-    CComPtr<IMalloc> pMalloc;
     CComPtr<AbstractMemoryStream> pOutputStream;
     CComPtr<IDxcBlob> pOutputBlob;
 
-    IFT(CoGetMalloc(1, &pMalloc));
-    IFT(CreateMemoryStream(pMalloc, &pOutputStream));
+    IFT(CreateMemoryStream(m_pMalloc, &pOutputStream));
     IFT(pOutputStream.QueryInterface(&pOutputBlob));
 
     raw_stream_ostream outStream(pOutputStream.p);
@@ -745,7 +746,7 @@ HRESULT STDMETHODCALLTYPE DxcOptimizer::RunOptimizer(
     }
     if (ppOutputModule != nullptr) {
       CComPtr<AbstractMemoryStream> pProgramStream;
-      IFT(CreateMemoryStream(pMalloc, &pProgramStream));
+      IFT(CreateMemoryStream(m_pMalloc, &pProgramStream));
       {
         raw_stream_ostream outStream(pProgramStream.p);
         WriteBitcodeToFile(M.get(), outStream, true);
@@ -759,7 +760,7 @@ HRESULT STDMETHODCALLTYPE DxcOptimizer::RunOptimizer(
 }
 
 HRESULT CreateDxcOptimizer(_In_ REFIID riid, _Out_ LPVOID *ppv) {
-  CComPtr<DxcOptimizer> result = new (std::nothrow) DxcOptimizer();
+  CComPtr<DxcOptimizer> result = DxcOptimizer::Alloc(DxcGetThreadMallocNoRef());
   if (result == nullptr) {
     *ppv = nullptr;
     return E_OUTOFMEMORY;

+ 2 - 6
lib/HLSL/DxilContainerAssembler.cpp

@@ -854,9 +854,7 @@ void hlsl::SerializeDxilContainerForModule(DxilModule *pModule,
         [&](AbstractMemoryStream *pStream) { rootSigWriter.write(pStream); });
     pModule->StripRootSignatureFromMetadata();
     pInputProgramStream.Release();
-    CComPtr<IMalloc> pMalloc;
-    IFT(CoGetMalloc(1, &pMalloc));
-    IFT(CreateMemoryStream(pMalloc, &pInputProgramStream));
+    IFT(CreateMemoryStream(DxcGetThreadMallocNoRef(), &pInputProgramStream));
     raw_stream_ostream outStream(pInputProgramStream.p);
     WriteBitcodeToFile(pModule->GetModule(), outStream, true);
   }
@@ -877,9 +875,7 @@ void hlsl::SerializeDxilContainerForModule(DxilModule *pModule,
     llvm::StripDebugInfo(*pModule->GetModule());
     pModule->StripDebugRelatedCode();
 
-    CComPtr<IMalloc> pMalloc;
-    IFT(CoGetMalloc(1, &pMalloc));
-    IFT(CreateMemoryStream(pMalloc, &pProgramStream));
+    IFT(CreateMemoryStream(DxcGetThreadMallocNoRef(), &pProgramStream));
     raw_stream_ostream outStream(pProgramStream.p);
     WriteBitcodeToFile(pModule->GetModule(), outStream, true);
 

+ 14 - 11
lib/HLSL/DxilContainerReflection.cpp

@@ -46,18 +46,18 @@ using namespace hlsl;
 
 class DxilContainerReflection : public IDxcContainerReflection {
 private:
-  DXC_MICROCOM_REF_FIELD(m_dwRef)
+  DXC_MICROCOM_TM_REF_FIELDS()
   CComPtr<IDxcBlob> m_container;
-  const DxilContainerHeader *m_pHeader;
-  uint32_t m_headerLen;
+  const DxilContainerHeader *m_pHeader = nullptr;
+  uint32_t m_headerLen = 0;
   bool IsLoaded() const { return m_pHeader != nullptr; }
 public:
-  DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
+  DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
+  DXC_MICROCOM_TM_CTOR(DxilContainerReflection)
   HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
     return DoBasicQueryInterface<IDxcContainerReflection>(this, iid, ppvObject);
   }
 
-  DxilContainerReflection() : m_dwRef(0), m_pHeader(nullptr), m_headerLen(0) { }
   __override HRESULT STDMETHODCALLTYPE Load(_In_ IDxcBlob *pContainer);
   __override HRESULT STDMETHODCALLTYPE GetPartCount(_Out_ UINT32 *pResult);
   __override HRESULT STDMETHODCALLTYPE GetPartKind(UINT32 idx, _Out_ UINT32 *pResult);
@@ -70,11 +70,11 @@ class CShaderReflectionConstantBuffer;
 class CShaderReflectionType;
 class DxilShaderReflection : public ID3D12ShaderReflection {
 private:
-  DXC_MICROCOM_REF_FIELD(m_dwRef)
+  DXC_MICROCOM_TM_REF_FIELDS()
   CComPtr<IDxcBlob> m_pContainer;
   LLVMContext Context;
   std::unique_ptr<Module> m_pModule; // Must come after LLVMContext, otherwise unique_ptr will over-delete.
-  DxilModule *m_pDxilModule;
+  DxilModule *m_pDxilModule = nullptr;
   std::vector<CShaderReflectionConstantBuffer>    m_CBs;
   std::vector<D3D12_SHADER_INPUT_BIND_DESC>       m_Resources;
   std::vector<D3D12_SIGNATURE_PARAMETER_DESC>     m_InputSignature;
@@ -103,7 +103,8 @@ public:
       api = DxilShaderReflection::PublicAPI::D3D11_47;
     return api;
   }
-  DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
+  DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
+  DXC_MICROCOM_TM_CTOR(DxilShaderReflection)
   HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
     HRESULT hr = DoBasicQueryInterface<ID3D12ShaderReflection>(this, iid, ppvObject);
     if (hr == E_NOINTERFACE) {
@@ -119,7 +120,6 @@ public:
     return hr;
   }
 
-  DxilShaderReflection() : m_dwRef(0), m_pDxilModule(nullptr) { }
   HRESULT Load(IDxcBlob *pBlob, const DxilPartHeader *pPart);
 
   // ID3D12ShaderReflection
@@ -216,6 +216,7 @@ HRESULT DxilContainerReflection::GetPartContent(UINT32 idx, _COM_Outptr_ IDxcBlo
   const char *pData = GetDxilPartData(pPart);
   uint32_t offset = (uint32_t)(pData - (char*)m_container->GetBufferPointer()); // Offset from the beginning.
   uint32_t length = pPart->PartSize;
+  DxcThreadMalloc TM(m_pMalloc);
   return DxcCreateBlobFromBlob(m_container, offset, length, ppResult);
 }
 
@@ -241,8 +242,9 @@ HRESULT DxilContainerReflection::GetPartReflection(UINT32 idx, REFIID iid, void
     return E_NOTIMPL;
   }
   
+  DxcThreadMalloc TM(m_pMalloc);
   HRESULT hr = S_OK;
-  CComPtr<DxilShaderReflection> pReflection = new (std::nothrow)DxilShaderReflection();
+  CComPtr<DxilShaderReflection> pReflection = DxilShaderReflection::Alloc(m_pMalloc);
   IFCOOM(pReflection.p);
   DxilShaderReflection::PublicAPI api = DxilShaderReflection::IIDToAPI(iid);
   pReflection->SetPublicAPI(api);
@@ -254,8 +256,9 @@ Cleanup:
 }
 
 void hlsl::CreateDxcContainerReflection(IDxcContainerReflection **ppResult) {
-  CComPtr<DxilContainerReflection> pReflection = new DxilContainerReflection();
+  CComPtr<DxilContainerReflection> pReflection = DxilContainerReflection::Alloc(DxcGetThreadMallocNoRef());
   *ppResult = pReflection.Detach();
+  if (*ppResult == nullptr) throw std::bad_alloc();
 }
 
 ///////////////////////////////////////////////////////////////////////////////

+ 3 - 1
lib/HLSL/DxilGenerationPass.cpp

@@ -103,6 +103,8 @@ void InitResource(const DxilResource *pSource, DxilResource *pDest) {
 }
 
 void InitDxilModuleFromHLModule(HLModule &H, DxilModule &M, DxilEntrySignature *pSig, bool HasDebugInfo) {
+  std::unique_ptr<DxilEntrySignature> pSigPtr(pSig);
+
   // Subsystems.
   unsigned ValMajor, ValMinor;
   H.GetValidatorVersion(ValMajor, ValMinor);
@@ -158,7 +160,7 @@ void InitDxilModuleFromHLModule(HLModule &H, DxilModule &M, DxilEntrySignature *
   }
 
   // Signatures.
-  M.ResetEntrySignature(pSig);
+  M.ResetEntrySignature(pSigPtr.release());
   M.ResetRootSignature(H.ReleaseRootSignature());
 
   // Shader properties.

+ 5 - 3
lib/HLSL/DxilRootSignature.cpp

@@ -95,7 +95,9 @@ void RootSignatureHandle::Deserialize() {
 void RootSignatureHandle::LoadSerialized(const uint8_t *pData,
                                          unsigned length) {
   DXASSERT_NOMSG(IsEmpty());
-  IFT(DxcCreateBlobOnHeapCopy(pData, length, &m_pSerialized));
+  IDxcBlobEncoding *pCreated;
+  IFT(DxcCreateBlobWithEncodingOnHeapCopy(pData, length, CP_UTF8, &pCreated));
+  m_pSerialized = pCreated;
 }
 
 //////////////////////////////////////////////////////////////////////////////
@@ -1345,11 +1347,11 @@ void SerializeRootSignatureTemplate(_In_ const T_ROOT_SIGNATURE_DESC* pRootSigna
   memcpy(pSS, pRS->pStaticSamplers, StaticSamplerSize);
 
   // Create the result blob.
-  CComHeapPtr<char> bytes;
+  CDxcMallocHeapPtr<char> bytes(DxcGetThreadMallocNoRef());
   CComPtr<IDxcBlob> pBlob;
   unsigned cb = Serializer.GetSize();
   DXASSERT_NOMSG((cb & 0x3) == 0);
-  IFTBOOL(bytes.AllocateBytes(cb), E_OUTOFMEMORY);
+  IFTBOOL(bytes.Allocate(cb), E_OUTOFMEMORY);
   IFT(Serializer.Compact(bytes.m_pData, cb));
   IFT(DxcCreateBlobOnHeap(bytes.m_pData, cb, ppBlob));
   bytes.Detach(); // Ownership transfered to ppBlob.

+ 2 - 6
lib/HLSL/DxilValidation.cpp

@@ -4189,10 +4189,8 @@ static void VerifyBlobPartMatches(_In_ ValidationContext &ValCtx,
     return;
   }
 
-  CComPtr<IMalloc> pMalloc;
-  IFT(CoGetMalloc(1, &pMalloc));
   CComPtr<AbstractMemoryStream> pOutputStream;
-  IFT(CreateMemoryStream(pMalloc, &pOutputStream));
+  IFT(CreateMemoryStream(DxcGetThreadMallocNoRef(), &pOutputStream));
   pOutputStream->Reserve(Size);
 
   pWriter->write(pOutputStream);
@@ -4504,10 +4502,8 @@ HRESULT ValidateDxilBitcode(
   if (!dxilModule.GetRootSignature().IsEmpty()) {
     unique_ptr<DxilPartWriter> pWriter(NewPSVWriter(dxilModule, 0));
     DXASSERT_NOMSG(pWriter->size());
-    CComPtr<IMalloc> pMalloc;
-    IFT(CoGetMalloc(1, &pMalloc));
     CComPtr<AbstractMemoryStream> pOutputStream;
-    IFT(CreateMemoryStream(pMalloc, &pOutputStream));
+    IFT(CreateMemoryStream(DxcGetThreadMallocNoRef(), &pOutputStream));
     pOutputStream->Reserve(pWriter->size());
     pWriter->write(pOutputStream);
     const DxilVersionedRootSignatureDesc* pDesc = dxilModule.GetRootSignature().GetDesc();

+ 17 - 0
lib/IR/DiagnosticInfo.cpp

@@ -52,6 +52,8 @@ struct PassRemarksOpt {
   };
 };
 
+#if 0
+// These should all be specific to a pipline, not global to the process.
 static PassRemarksOpt PassRemarksOptLoc;
 static PassRemarksOpt PassRemarksMissedOptLoc;
 static PassRemarksOpt PassRemarksAnalysisOptLoc;
@@ -85,6 +87,21 @@ PassRemarksAnalysis(
     cl::Hidden, cl::location(PassRemarksAnalysisOptLoc), cl::ValueRequired,
     cl::ZeroOrMore);
 }
+#else
+struct PassRemarksOptNull {
+  Regex *Pattern = nullptr;
+  void operator=(const std::string &Val) {
+  }
+};
+static PassRemarksOptNull PassRemarksOptLoc;
+static PassRemarksOptNull PassRemarksMissedOptLoc;
+static PassRemarksOptNull PassRemarksAnalysisOptLoc;
+
+static PassRemarksOptNull PassRemarks;
+static PassRemarksOptNull PassRemarksMissed;
+static PassRemarksOptNull PassRemarksAnalysis;
+}
+#endif
 
 int llvm::getNextAvailablePluginDiagnosticKind() {
   static std::atomic<int> PluginKindID(DK_FirstPluginKind);

+ 2 - 0
lib/IR/Dominators.cpp

@@ -35,9 +35,11 @@ static bool VerifyDomInfo = true;
 #else
 static bool VerifyDomInfo = false;
 #endif
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool,true>
 VerifyDomInfoX("verify-dom-info", cl::location(VerifyDomInfo),
                cl::desc("Verify dominator info (time consuming)"));
+#endif // HLSL Change Ends
 
 bool BasicBlockEdge::isSingleEdge() const {
   const TerminatorInst *TI = Start->getTerminator();

+ 14 - 0
lib/IR/Function.cpp

@@ -383,26 +383,39 @@ static StringPool *GCNamePool;
 static ManagedStatic<sys::SmartRWMutex<true> > GCLock;
 
 bool Function::hasGC() const {
+#if 0 // HLSL Change
   sys::SmartScopedReader<true> Reader(*GCLock);
   return GCNames && GCNames->count(this);
+#else
+  return false;
+#endif // HLSL Change Ends
 }
 
 const char *Function::getGC() const {
+#if 0 // HLSL Change
   assert(hasGC() && "Function has no collector");
   sys::SmartScopedReader<true> Reader(*GCLock);
   return *(*GCNames)[this];
+#else
+  return nullptr;
+#endif // HLSL Change Ends
 }
 
 void Function::setGC(const char *Str) {
+#if 0 // HLSL Change Starts
   sys::SmartScopedWriter<true> Writer(*GCLock);
   if (!GCNamePool)
     GCNamePool = new StringPool();
   if (!GCNames)
     GCNames = new DenseMap<const Function*,PooledStringPtr>();
   (*GCNames)[this] = GCNamePool->intern(Str);
+#else
+  assert(false && "GC not supported");
+#endif // HLSL Change Ends
 }
 
 void Function::clearGC() {
+#if 0 // HLSL Change Starts
   sys::SmartScopedWriter<true> Writer(*GCLock);
   if (GCNames) {
     GCNames->erase(this);
@@ -415,6 +428,7 @@ void Function::clearGC() {
       }
     }
   }
+#endif // HLSL Change Ends
 }
 
 /// copyAttributesFrom - copy all additional attributes (those not needed to

+ 2 - 2
lib/IR/LLVMContextImpl.cpp

@@ -85,9 +85,9 @@ LLVMContextImpl::~LLVMContextImpl() {
 
   // Also drop references that come from the Value bridges.
   for (auto &Pair : ValuesAsMetadata)
-    Pair.second->dropUsers();
+    if (Pair.second) Pair.second->dropUsers(); // HLSL Change - if alloc failed, entry might not be populated
   for (auto &Pair : MetadataAsValues)
-    Pair.second->dropUse();
+    if (Pair.second) Pair.second->dropUse(); // HLSL Change - if alloc failed, entry might not be populated
 
   // Destroy MDNodes.
   for (MDNode *I : DistinctMDNodes)

+ 36 - 6
lib/IR/LegacyPassManager.cpp

@@ -46,6 +46,7 @@ enum PassDebugLevel {
 };
 }
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<enum PassDebugLevel>
 PassDebugging("debug-pass", cl::Hidden,
                   cl::desc("Print PassManager debugging information"),
@@ -56,12 +57,16 @@ PassDebugging("debug-pass", cl::Hidden,
   clEnumVal(Executions, "print pass name before it is executed"),
   clEnumVal(Details   , "print pass details when it is executed"),
                              clEnumValEnd));
+#else
+static const PassDebugLevel PassDebugging  = PassDebugLevel::Disabled;
+#endif // HLSL Change Ends
 
 namespace {
 typedef llvm::cl::list<const llvm::PassInfo *, bool, PassNameParser>
 PassOptionList;
 }
 
+#if 0 // HLSL Change Starts - option pending
 // Print IR out before/after specified passes.
 static PassOptionList
 PrintBefore("print-before",
@@ -81,6 +86,10 @@ static cl::opt<bool>
 PrintAfterAll("print-after-all",
               llvm::cl::desc("Print IR after each pass"),
               cl::init(false));
+#else
+static const bool PrintBeforeAll = false;
+static const bool PrintAfterAll = false;
+#endif // HLSL Change Ends
 
 /// This is a helper to determine whether to print IR before or
 /// after a pass.
@@ -99,13 +108,13 @@ static bool ShouldPrintBeforeOrAfterPass(const PassInfo *PI,
 /// This is a utility to check whether a pass should have IR dumped
 /// before it.
 static bool ShouldPrintBeforePass(const PassInfo *PI) {
-  return PrintBeforeAll || ShouldPrintBeforeOrAfterPass(PI, PrintBefore);
+  return false; // HLSL Change - return PrintBeforeAll || ShouldPrintBeforeOrAfterPass(PI, PrintBefore);
 }
 
 /// This is a utility to check whether a pass should have IR dumped
 /// after it.
 static bool ShouldPrintAfterPass(const PassInfo *PI) {
-  return PrintAfterAll || ShouldPrintBeforeOrAfterPass(PI, PrintAfter);
+  return false; // HLSL Change -PrintAfterAll || ShouldPrintBeforeOrAfterPass(PI, PrintAfter);
 }
 
 /// isPassDebuggingExecutionsOrMore - Return true if -debug-pass=Executions
@@ -573,8 +582,10 @@ AnalysisUsage *PMTopLevelManager::findAnalysisUsage(Pass *P) {
     AnUsage = DMI->second;
   else {
     AnUsage = new AnalysisUsage();
+    std::unique_ptr<AnalysisUsage> AnUsagePtr(AnUsage); // HLSL Change - unique_ptr until added
     P->getAnalysisUsage(*AnUsage);
     AnUsageMap[P] = AnUsage;
+    AnUsagePtr.release(); // HLSL Change
   }
   return AnUsage;
 }
@@ -587,6 +598,8 @@ void PMTopLevelManager::schedulePass(Pass *P) {
   // TODO : Allocate function manager for this pass, other wise required set
   // may be inserted into previous function manager
 
+  std::unique_ptr<Pass> PPtr(P); // take ownership locally until we pass it on
+
   // Give pass a chance to prepare the stage.
   P->preparePassManager(activeStack);
 
@@ -595,7 +608,7 @@ void PMTopLevelManager::schedulePass(Pass *P) {
   // available at this point.
   const PassInfo *PI = findAnalysisPassInfo(P->getPassID());
   if (PI && PI->isAnalysis() && findAnalysisPass(P->getPassID())) {
-    delete P;
+    // delete P; // HLSL Change - let PPtr take care of this
     return;
   }
 
@@ -658,9 +671,10 @@ void PMTopLevelManager::schedulePass(Pass *P) {
     // top level manager. Set up analysis resolver to connect them.
     PMDataManager *DM = getAsPMDataManager();
     AnalysisResolver *AR = new AnalysisResolver(*DM);
-    P->setResolver(AR);
+    P->setResolver(AR); // HLSL Comment - P takes ownership of AR here
     DM->initializeAnalysisImpl(P);
     addImmutablePass(IP);
+    PPtr.release(); // HLSL Change
     DM->recordAvailableAnalysis(IP);
     return;
   }
@@ -672,6 +686,7 @@ void PMTopLevelManager::schedulePass(Pass *P) {
   }
 
   // Add the requested pass to the best available pass manager.
+  PPtr.release(); // HLSL Change - assignPassManager takes ownership
   P->assignPassManager(activeStack, getTopLevelPassManagerType());
 
   if (PI && !PI->isAnalysis() && ShouldPrintAfterPass(PI)) {
@@ -971,10 +986,11 @@ void PMDataManager::freePass(Pass *P, StringRef Msg,
 /// Add pass P into the PassVector. Update
 /// AvailableAnalysis appropriately if ProcessAnalysis is true.
 void PMDataManager::add(Pass *P, bool ProcessAnalysis) {
+  std::unique_ptr<Pass> PPtr(P); // HLSL Change - take ownership of P
   // This manager is going to manage pass P. Set up analysis resolver
   // to connect them.
   AnalysisResolver *AR = new AnalysisResolver(*this);
-  P->setResolver(AR);
+  P->setResolver(AR); // HLSL Note: setResolver takes onwership of AR
 
   // If a FunctionPass F is the last user of ModulePass info M
   // then the F's manager, not F, records itself as a last user of M.
@@ -983,6 +999,7 @@ void PMDataManager::add(Pass *P, bool ProcessAnalysis) {
   if (!ProcessAnalysis) {
     // Add pass
     PassVector.push_back(P);
+    PPtr.release(); // HLSL Change
     return;
   }
 
@@ -1044,6 +1061,7 @@ void PMDataManager::add(Pass *P, bool ProcessAnalysis) {
 
   // Add pass
   PassVector.push_back(P);
+  PPtr.release(); // HLSL Change
 }
 
 
@@ -1370,7 +1388,11 @@ FunctionPassManager::FunctionPassManager(Module *m) : M(m) {
   // FPM is the top level manager.
   FPM->setTopLevelManager(FPM);
 
-  AnalysisResolver *AR = new AnalysisResolver(*FPM);
+  AnalysisResolver *AR = new (std::nothrow)AnalysisResolver(*FPM); // HLSL Change: nothrow and recover
+  if (!AR) {
+    delete FPM;
+    throw std::bad_alloc();
+  }
   FPM->setResolver(AR);
 }
 
@@ -1380,10 +1402,12 @@ FunctionPassManager::~FunctionPassManager() {
 
 void FunctionPassManager::add(Pass *P) {
   // HLSL Change Starts
+  std::unique_ptr<Pass> PPtr(P); // take ownership of P, even on failure paths
   if (TrackPassOS) {
     P->dumpConfig(*TrackPassOS);
     (*TrackPassOS) << '\n';
   }
+  PPtr.release();
   // HLSL Change Ends
   FPM->add(P);
 }
@@ -1726,10 +1750,12 @@ PassManager::~PassManager() {
 
 void PassManager::add(Pass *P) {
   // HLSL Change Starts
+  std::unique_ptr<Pass> PPtr(P); // take ownership of P, even on failure paths
   if (TrackPassOS) {
     P->dumpConfig(*TrackPassOS);
     (*TrackPassOS) << '\n';
   }
+  PPtr.release();
   // HLSL Change Ends
   PM->add(P);
 }
@@ -1744,9 +1770,11 @@ bool PassManager::run(Module &M) {
 // TimingInfo implementation
 
 bool llvm::TimePassesIsEnabled = false;
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool,true>
 EnableTiming("time-passes", cl::location(TimePassesIsEnabled),
             cl::desc("Time each pass, printing elapsed time for each on exit"));
+#endif
 
 // createTheTimeInfo - This method either initializes the TheTimeInfo pointer to
 // a non-null value (if the -time-passes option is enabled) or it leaves it
@@ -1836,6 +1864,7 @@ void ModulePass::assignPassManager(PMStack &PMS,
 /// in the PM Stack and add self into that manager.
 void FunctionPass::assignPassManager(PMStack &PMS,
                                      PassManagerType PreferredType) {
+  std::unique_ptr<FunctionPass> thisPtr(this); // HLSL Change
 
   // Find Function Pass Manager
   while (!PMS.empty()) {
@@ -1870,6 +1899,7 @@ void FunctionPass::assignPassManager(PMStack &PMS,
   }
 
   // Assign FPP as the manager of this pass.
+  thisPtr.release();
   FPP->add(this);
 }
 

+ 14 - 6
lib/IR/Metadata.cpp

@@ -134,7 +134,7 @@ void ReplaceableMetadataImpl::addRef(void *Ref, OwnerTy Owner) {
 void ReplaceableMetadataImpl::dropRef(void *Ref) {
   bool WasErased = UseMap.erase(Ref);
   (void)WasErased;
-  assert(WasErased && "Expected to drop a reference");
+  // assert(WasErased && "Expected to drop a reference"); // HLSL Change - not while cleaning up OOM
 }
 
 void ReplaceableMetadataImpl::moveRef(void *Ref, void *New,
@@ -283,12 +283,12 @@ void ValueAsMetadata::handleDeletion(Value *V) {
 
   // Remove old entry from the map.
   ValueAsMetadata *MD = I->second;
-  assert(MD && "Expected valid metadata");
-  assert(MD->getValue() == V && "Expected valid mapping");
+  // assert(MD && "Expected valid metadata"); // HLSL Change - MD might be nullptr under OOM
+  // assert(MD->getValue() == V && "Expected valid mapping"); // HLSL Change - MD might be nullptr under OOM
   Store.erase(I);
 
   // Delete the metadata.
-  MD->replaceAllUsesWith(nullptr);
+  if (MD) MD->replaceAllUsesWith(nullptr); // HLSL Change - MD might be nullptr under OOM
   delete MD;
 }
 
@@ -710,8 +710,16 @@ MDTuple *MDTuple::getImpl(LLVMContext &Context, ArrayRef<Metadata *> MDs,
     assert(ShouldCreate && "Expected non-uniqued nodes to always be created");
   }
 
-  return storeImpl(new (MDs.size()) MDTuple(Context, Storage, Hash, MDs),
-                   Storage, Context.pImpl->MDTuples);
+  // HLSL Change - guard with try/catch
+  MDTuple *MDTuplePtr(new (MDs.size()) MDTuple(Context, Storage, Hash, MDs));
+  MDTuple *Result;
+  try {
+    Result = storeImpl(MDTuplePtr, Storage, Context.pImpl->MDTuples);
+  } catch (...) {
+    MDTuplePtr->deleteAsSubclass();
+    throw;
+  }
+  return Result;
 }
 
 void MDNode::deleteTemporary(MDNode *N) {

+ 5 - 2
lib/IR/Module.cpp

@@ -47,9 +47,12 @@ template class llvm::SymbolTableListTraits<GlobalAlias, Module>;
 
 Module::Module(StringRef MID, LLVMContext &C)
     : Context(C), Materializer(), ModuleID(MID), DL("") {
-  ValSymTab = new ValueSymbolTable();
-  NamedMDSymTab = new StringMap<NamedMDNode *>();
+  // HLSL Change - use unique_ptr to avoid leaks
+  std::unique_ptr<ValueSymbolTable> ValSymTabPtr(new ValueSymbolTable());
+  std::unique_ptr<StringMap<NamedMDNode *> > NamedMDSymTabPtr(new StringMap<NamedMDNode *>());
   Context.addModule(this);
+  ValSymTab = ValSymTabPtr.release();
+  NamedMDSymTab = NamedMDSymTabPtr.release();
 }
 
 Module::~Module() {

+ 9 - 0
lib/IR/User.cpp

@@ -136,6 +136,15 @@ void User::operator delete(void *Usr) {
   }
 }
 
+// HLSL Change Starts
+void User::operator delete(void *Usr, unsigned NumUserOperands) {
+  // Fun fact: during construction Obj->NumUserOperands is overwritten
+  Use *Storage = static_cast<Use *>(Usr) - NumUserOperands;
+  Use::zap(Storage, Storage + NumUserOperands, /* Delete */ false);
+  ::operator delete(Storage);
+}
+// HLSL Change Ends
+
 //===----------------------------------------------------------------------===//
 //                             Operator Class
 //===----------------------------------------------------------------------===//

+ 3 - 3
lib/IR/Value.cpp

@@ -181,8 +181,8 @@ void Value::setValueName(ValueName *VN) {
     return;
   }
 
-  HasName = true;
   Ctx.pImpl->ValueNames[this] = VN;
+  HasName = true; // HLSL Change - only set this to true after assignment
 }
 
 StringRef Value::getName() const {
@@ -608,9 +608,9 @@ void ValueHandleBase::AddToUseList() {
 }
 
 void ValueHandleBase::RemoveFromUseList() {
-  assert(V && V->HasValueHandle &&
+  assert(V && (std::current_exception() == nullptr || V->HasValueHandle) && // HLSL Change
          "Pointer doesn't have a use list!");
-
+  if (!V->HasValueHandle) return; // HLSL Change
   // Unlink this from its use list.
   ValueHandleBase **PrevPtr = getPrevPtr();
   assert(*PrevPtr == this && "List invariant broken");

+ 4 - 0
lib/IR/Verifier.cpp

@@ -78,7 +78,11 @@
 #include <cstdarg>
 using namespace llvm;
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool> VerifyDebugInfo("verify-debug-info", cl::init(true));
+#else
+static const bool VerifyDebugInfo = true;
+#endif // HLSL Change Ends
 
 namespace {
 struct VerifierSupport {

+ 17 - 4
lib/Support/CommandLine.cpp

@@ -234,7 +234,7 @@ void Option::setArgStr(const char *S) {
 }
 
 // Initialise the general option category.
-OptionCategory llvm::cl::GeneralCategory("General options");
+OptionCategory *llvm::cl::GeneralCategory; // HLSL Change - GeneralCategory is now a pointer
 
 void OptionCategory::registerCategory() {
   GlobalParser->registerCategory(this);
@@ -817,7 +817,7 @@ void cl::ParseCommandLineOptions(int argc, const char *const *argv,
 void CommandLineParser::ParseCommandLineOptions(int argc,
                                                 const char *const *argv,
                                                 const char *Overview) {
-  assert(hasOptions() && "No options specified!");
+  // assert(hasOptions() && "No options specified!"); // HLSL Change - it's valid to have no options for the DLL build
 
   // Expand response files.
   SmallVector<const char *, 20> newArgv(argv, argv + argc);
@@ -1657,6 +1657,8 @@ static HelpPrinterWrapper WrappedNormalPrinter(UncategorizedNormalPrinter,
 static HelpPrinterWrapper WrappedHiddenPrinter(UncategorizedHiddenPrinter,
                                                CategorizedHiddenPrinter);
 
+#if 0 // HLSL Change Starts
+
 // Define a category for generic options that all tools should have.
 static cl::OptionCategory GenericCategory("Generic Options");
 
@@ -1713,6 +1715,13 @@ void HelpPrinterWrapper::operator=(bool Value) {
   } else
     UncategorizedPrinter = true; // Invoke uncategorized printer
 }
+#else
+static const bool PrintOptions = false;
+static const bool PrintAllOptions = false;
+
+void HelpPrinterWrapper::operator=(bool Value) {
+}
+#endif // HLSL Change Ends
 
 // Print the value of each option.
 void cl::PrintOptionValues() { GlobalParser->printOptionValues(); }
@@ -1795,10 +1804,14 @@ public:
 // Define the --version option that prints out the LLVM version for the tool
 static VersionPrinter VersionPrinterInstance;
 
+#if 0 // HLSL Change Starts
 static cl::opt<VersionPrinter, true, parser<bool>>
     VersOp("version", cl::desc("Display the version of this program"),
            cl::location(VersionPrinterInstance), cl::ValueDisallowed,
            cl::cat(GenericCategory));
+#else
+static const OptionCategory *GenericCategory;
+#endif // HLSL Change Ends
 
 // Utility function for printing the help message.
 void cl::PrintHelpMessage(bool Hidden, bool Categorized) {
@@ -1838,7 +1851,7 @@ StringMap<Option *> &cl::getRegisteredOptions() {
 void cl::HideUnrelatedOptions(cl::OptionCategory &Category) {
   for (auto &I : GlobalParser->OptionsMap) {
     if (I.second->Category != &Category &&
-        I.second->Category != &GenericCategory)
+        I.second->Category != GenericCategory) // HLSL Change - use pointer
       I.second->setHiddenFlag(cl::ReallyHidden);
   }
 }
@@ -1849,7 +1862,7 @@ void cl::HideUnrelatedOptions(ArrayRef<const cl::OptionCategory *> Categories) {
   for (auto &I : GlobalParser->OptionsMap) {
     if (std::find(CategoriesBegin, CategoriesEnd, I.second->Category) ==
             CategoriesEnd &&
-        I.second->Category != &GenericCategory)
+        I.second->Category != GenericCategory) // HLSL Change - use pointer
       I.second->setHiddenFlag(cl::ReallyHidden);
   }
 }

+ 5 - 4
lib/Support/ErrorHandling.cpp

@@ -45,18 +45,19 @@ using namespace llvm;
 thread_local static fatal_error_handler_t ErrorHandler = nullptr;
 thread_local static void *ErrorHandlerUserData = nullptr;
 
-static ManagedStatic<sys::Mutex> ErrorHandlerMutex;
+// HLSL Change - no mutex needed, handlers are thread_local
+//static ManagedStatic<sys::Mutex> ErrorHandlerMutex;
 
 void llvm::install_fatal_error_handler(fatal_error_handler_t handler,
                                        void *user_data) {
-  llvm::MutexGuard Lock(*ErrorHandlerMutex);
+  // llvm::MutexGuard Lock(*ErrorHandlerMutex); // HLSL Change - ErrorHandler and user data already thread-local
   assert(!ErrorHandler && "Error handler already registered!\n");
   ErrorHandler = handler;
   ErrorHandlerUserData = user_data;
 }
 
 void llvm::remove_fatal_error_handler() {
-  llvm::MutexGuard Lock(*ErrorHandlerMutex);
+  // llvm::MutexGuard Lock(*ErrorHandlerMutex); // HLSL Change - ErrorHandler and user data already thread-local
   ErrorHandler = nullptr;
   ErrorHandlerUserData = nullptr;
 }
@@ -79,7 +80,7 @@ void llvm::report_fatal_error(const Twine &Reason, bool GenCrashDiag) {
   {
     // Only acquire the mutex while reading the handler, so as not to invoke a
     // user-supplied callback under a lock.
-    llvm::MutexGuard Lock(*ErrorHandlerMutex);
+    // llvm::MutexGuard Lock(*ErrorHandlerMutex); // HLSL Change - ErrorHandler and user data already thread-local
     handler = ErrorHandler;
     handlerData = ErrorHandlerUserData;
   }

+ 4 - 0
lib/Support/GraphWriter.cpp

@@ -18,8 +18,12 @@
 #include "llvm/Support/Program.h"
 using namespace llvm;
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool> ViewBackground("view-background", cl::Hidden,
   cl::desc("Execute graph viewer in the background. Creates tmp file litter."));
+#else
+static const bool ViewBackground = false;
+#endif // HLSL Change Ends
 
 std::string llvm::DOT::EscapeString(const std::string &Label) {
   std::string Str(Label);

+ 4 - 0
lib/Support/RandomNumberGenerator.cpp

@@ -26,9 +26,13 @@ using namespace llvm;
 // http://llvm.org/bugs/show_bug.cgi?id=19665
 //
 // Do not change to cl::opt<uint64_t> since this silently breaks argument parsing.
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<unsigned long long>
 Seed("rng-seed", cl::value_desc("seed"),
      cl::desc("Seed for the random number generator"), cl::init(0));
+#else
+static const unsigned long long Seed = 0; // will go boom in the constructor, can't be set yet
+#endif // HLSL Change Ends
 
 RandomNumberGenerator::RandomNumberGenerator(StringRef Salt) {
   DEBUG(

+ 5 - 2
lib/Support/Statistic.cpp

@@ -39,11 +39,14 @@ namespace llvm { extern raw_ostream *CreateInfoOutputFile(); }
 /// -stats - Command line option to cause transformations to emit stats about
 /// what they did.
 ///
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool>
 Enabled(
     "stats",
     cl::desc("Enable statistics output from program (available with Asserts)"));
-
+#else
+static const bool Enabled = false;
+#endif // HLSL Change Ends
 
 namespace {
 /// StatisticInfo - This class is used in a ManagedStatic so that it is created
@@ -90,7 +93,7 @@ StatisticInfo::~StatisticInfo() {
 }
 
 void llvm::EnableStatistics() {
-  Enabled.setValue(true);
+  //Enabled.setValue(true); // HLSL Change
 }
 
 bool llvm::AreStatisticsEnabled() {

+ 9 - 9
lib/Support/Timer.cpp

@@ -41,6 +41,7 @@ static std::string &getLibSupportInfoOutputFilename() {
 static ManagedStatic<sys::SmartMutex<true> > TimerLock;
 
 namespace {
+#if 0 // HLSL Change Starts - option pending
   static cl::opt<bool>
   TrackSpace("track-memory", cl::desc("Enable -time-passes memory "
                                       "tracking (this may be slow)"),
@@ -50,6 +51,10 @@ namespace {
   InfoOutputFilename("info-output-file", cl::value_desc("filename"),
                      cl::desc("File to append -stats and -timer output to"),
                    cl::Hidden, cl::location(getLibSupportInfoOutputFilename()));
+#else
+  static const bool TrackSpace = false;
+  static const char InfoOutputFilename[] = "";
+#endif // HLSL Change Ends
 }
 
 // CreateInfoOutputFile - Return a file stream to print our output on.
@@ -77,10 +82,10 @@ raw_ostream *llvm::CreateInfoOutputFile() {
 }
 
 #define DefaultTimerGroupName "Miscellaneous Ungrouped Timers"
-static TimerGroup DefaultTimerGroup(DefaultTimerGroupName); // HLSL Change - global init
+// static TimerGroup DefaultTimerGroup(DefaultTimerGroupName); // HLSL Change - global init
 static TimerGroup *getDefaultTimerGroup() {
 #if 1 // HLSL Change Starts - global with special clean-up and init
-  return &DefaultTimerGroup;
+  return nullptr; // rather than alloc-on-demand or &DefaultTimerGroup;
 #else
   TimerGroup *tmp = DefaultTimerGroup;
   sys::MemoryFence();
@@ -107,6 +112,7 @@ void Timer::init(StringRef N) {
   Name.assign(N.begin(), N.end());
   Started = false;
   TG = getDefaultTimerGroup();
+  if (!TG) return; // HLSL Change
   TG->addTimer(*this);
 }
 
@@ -277,12 +283,6 @@ TimerGroup::~TimerGroup() {
   // print the timing data.
   while (FirstTimer)
     removeTimer(*FirstTimer);
-
-  // HLSL Change Starts - don't bother cleaning up global
-  if (this == &DefaultTimerGroup) {
-    return;
-  }
-  // HLSL Change Ends
   
   // Remove the group from the TimerGroupList.
   sys::SmartScopedLock<true> L(*TimerLock);
@@ -346,7 +346,7 @@ void TimerGroup::PrintQueuedTimers(raw_ostream &OS) {
   // If this is not an collection of ungrouped times, print the total time.
   // Ungrouped timers don't really make sense to add up.  We still print the
   // TOTAL line to make the percentages make sense.
-  if (this != &DefaultTimerGroup)
+  if (this == getDefaultTimerGroup()) // HLSL Change
     OS << format("  Total Execution Time: %5.4f seconds (%5.4f wall clock)\n",
                  Total.getProcessTime(), Total.getWallTime());
   OS << '\n';

+ 4 - 3
lib/Support/Windows/Mutex.inc

@@ -24,15 +24,16 @@ using namespace sys;
 
 MutexImpl::MutexImpl(bool /*recursive*/)
 {
-  data_ = new CRITICAL_SECTION;
+  C_ASSERT(sizeof(data_) == sizeof(CRITICAL_SECTION));
+  // data_ = new CRITICAL_SECTION; // HLSL Change
   InitializeCriticalSection((LPCRITICAL_SECTION)data_);
 }
 
 MutexImpl::~MutexImpl()
 {
   DeleteCriticalSection((LPCRITICAL_SECTION)data_);
-  delete (LPCRITICAL_SECTION)data_;
-  data_ = 0;
+  // delete (LPCRITICAL_SECTION)data_; // HLSL Change
+  // data_ = 0; // HLSL Change
 }
 
 bool

+ 12 - 0
lib/Support/raw_ostream.cpp

@@ -746,7 +746,19 @@ raw_ostream &llvm::nulls() {
 //===----------------------------------------------------------------------===//
 
 raw_string_ostream::~raw_string_ostream() {
+#if 0 // HLSL Change Starts
   flush();
+#else
+  // C++ and exception in destructors don't play nice. The proper pattern
+  // here is to have the raw_string_ostream's owner flush before destruction
+  // and take appropriate action, like throwing or returning an error value.
+  try {
+    flush();
+  }
+  catch (const std::bad_alloc &) {
+    // Don't std::terminate()
+  }
+#endif // HLSL Change Ends
 }
 
 void raw_string_ostream::write_impl(const char *Ptr, size_t Size) {

+ 14 - 0
lib/Transforms/IPO/Inliner.cpp

@@ -46,6 +46,7 @@ STATISTIC(NumMergedAllocas, "Number of allocas merged together");
 // if those would be more profitable and blocked inline steps.
 STATISTIC(NumCallerCallersAnalyzed, "Number of caller-callers analyzed");
 
+#if 0 // HLSL Change Starts
 static cl::opt<int>
 InlineLimit("inline-threshold", cl::Hidden, cl::init(225), cl::ZeroOrMore,
         cl::desc("Control the amount of inlining to perform (default = 225)"));
@@ -60,6 +61,19 @@ HintThreshold("inlinehint-threshold", cl::Hidden, cl::init(325),
 static cl::opt<int>
 ColdThreshold("inlinecold-threshold", cl::Hidden, cl::init(225),
               cl::desc("Threshold for inlining functions with cold attribute"));
+#else
+struct NullOpt {
+  NullOpt(int val) : _val(val) {}
+  int _val;
+  int getNumOccurrences() const { return 0; }
+  operator int() const {
+    return _val;
+  }
+};
+static const NullOpt InlineLimit(225);
+static const NullOpt HintThreshold(325);
+static const NullOpt ColdThreshold(225);
+#endif // HLSL Change Ends
 
 // Threshold to use when optsize is specified (and there is no -inline-limit).
 const int OptSizeThreshold = 75;

+ 4 - 0
lib/Transforms/IPO/LowerBitSets.cpp

@@ -38,10 +38,14 @@ STATISTIC(NumByteArraysCreated, "Number of byte arrays created");
 STATISTIC(NumBitSetCallsLowered, "Number of bitset calls lowered");
 STATISTIC(NumBitSetDisjointSets, "Number of disjoint sets of bitsets");
 
+#if 0 // HLSL Change
 static cl::opt<bool> AvoidReuse(
     "lowerbitsets-avoid-reuse",
     cl::desc("Try to avoid reuse of byte array addresses using aliases"),
     cl::Hidden, cl::init(true));
+#else
+static bool AvoidReuse = true;
+#endif
 
 bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const {
   if (Offset < ByteOffset)

+ 4 - 0
lib/Transforms/IPO/MergeFunctions.cpp

@@ -112,12 +112,16 @@ STATISTIC(NumThunksWritten, "Number of thunks generated");
 STATISTIC(NumAliasesWritten, "Number of aliases generated");
 STATISTIC(NumDoubleWeak, "Number of new functions created");
 
+#if 0 // HLSL Change
 static cl::opt<unsigned> NumFunctionsForSanityCheck(
     "mergefunc-sanity",
     cl::desc("How many functions in module could be used for "
              "MergeFunctions pass sanity check. "
              "'0' disables this check. Works only with '-debug' key."),
     cl::init(0), cl::Hidden);
+#else
+static const unsigned NumFunctionsForSanityCheck = 0;
+#endif
 
 namespace {
 

+ 22 - 7
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -55,12 +55,6 @@ static cl::opt<bool> ExtraVectorizerPasses(
     "extra-vectorizer-passes", cl::init(false), cl::Hidden,
     cl::desc("Run cleanup optimization passes after vectorization."));
 
-#else
-
-// Don't declare the 'false' counterparts - simply avoid altogether.
-
-#endif // HLSL Change - don't build vectorization passes
-
 static cl::opt<bool> UseNewSROA("use-new-sroa",
   cl::init(true), cl::Hidden,
   cl::desc("Enable the new, experimental SROA pass"));
@@ -99,6 +93,21 @@ static cl::opt<bool> EnableLoopDistribute(
     "enable-loop-distribute", cl::init(false), cl::Hidden,
     cl::desc("Enable the new, experimental LoopDistribution Pass"));
 
+#else
+
+// Don't declare the 'false' counterparts - simply avoid altogether.
+
+static const bool UseNewSROA = true;
+static const bool RunLoopRerolling = false;
+static const bool RunFloat2Int = true;
+static const bool RunLoadCombine = false;
+static const bool RunSLPAfterLoopVectorization = true;
+static const bool UseCFLAA = false;
+static const bool EnableMLSM = true;
+static const bool EnableLoopInterchange = false;
+static const bool EnableLoopDistribute = false;
+#endif // HLSL Change - don't build vectorization passes
+
 PassManagerBuilder::PassManagerBuilder() {
     OptLevel = 2;
     SizeLevel = 0;
@@ -127,15 +136,19 @@ PassManagerBuilder::~PassManagerBuilder() {
   delete Inliner;
 }
 
+#if 0 // HLSL Change Starts - no global extensions
 /// Set of global extensions, automatically added as part of the standard set.
 static ManagedStatic<SmallVector<std::pair<PassManagerBuilder::ExtensionPointTy,
    PassManagerBuilder::ExtensionFn>, 8> > GlobalExtensions;
+#endif // HLSL Change Ends
 
+#if 0 // HLSL Change Starts - no global extensions
 void PassManagerBuilder::addGlobalExtension(
     PassManagerBuilder::ExtensionPointTy Ty,
     PassManagerBuilder::ExtensionFn Fn) {
   GlobalExtensions->push_back(std::make_pair(Ty, Fn));
 }
+#endif // HLSL Change Ends
 
 void PassManagerBuilder::addExtension(ExtensionPointTy Ty, ExtensionFn Fn) {
   Extensions.push_back(std::make_pair(Ty, Fn));
@@ -143,12 +156,14 @@ void PassManagerBuilder::addExtension(ExtensionPointTy Ty, ExtensionFn Fn) {
 
 void PassManagerBuilder::addExtensionsToPM(ExtensionPointTy ETy,
                                            legacy::PassManagerBase &PM) const {
+#if 0 // HLSL Change Starts - no global extensions
   for (unsigned i = 0, e = GlobalExtensions->size(); i != e; ++i)
     if ((*GlobalExtensions)[i].first == ETy)
       (*GlobalExtensions)[i].second(*this, PM);
   for (unsigned i = 0, e = Extensions.size(); i != e; ++i)
     if (Extensions[i].first == ETy)
       Extensions[i].second(*this, PM);
+#endif // HLSL Change Ends
 }
 
 void PassManagerBuilder::addInitialAliasAnalysisPasses(
@@ -268,7 +283,7 @@ void PassManagerBuilder::populateModulePassManager(
     // builds. The function merging pass is 
     if (MergeFunctions)
       MPM.add(createMergeFunctionsPass());
-    else if (!GlobalExtensions->empty() || !Extensions.empty())
+    else if (!Extensions.empty()) // HLSL Change - GlobalExtensions not considered
       MPM.add(createBarrierNoopPass());
 
     addExtensionsToPM(EP_EnabledOnOptLevel0, MPM);

+ 5 - 0
lib/Transforms/Scalar/Float2Int.cpp

@@ -44,12 +44,17 @@ using namespace llvm;
 // as non-transformable. If we see an instruction that converts from the 
 // integer domain to FP domain (uitofp,sitofp), we terminate our walk.
 
+#if 0 // HLSL Change Starts - option pending
 /// The largest integer type worth dealing with.
 static cl::opt<unsigned>
 MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden,
              cl::desc("Max integer bitwidth to consider in float2int"
                       "(default=64)"));
 
+#else
+static const unsigned MaxIntegerBW = 64;
+#endif // HLSL Change Ends
+
 namespace {
   struct Float2Int : public FunctionPass {
     static char ID; // Pass identification, replacement for typeid

+ 6 - 0
lib/Transforms/Scalar/GVN.cpp

@@ -65,6 +65,7 @@ STATISTIC(NumGVNSimpl,  "Number of instructions simplified");
 STATISTIC(NumGVNEqProp, "Number of equalities propagated");
 STATISTIC(NumPRELoad,   "Number of loads PRE'd");
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool> EnablePRE("enable-pre",
                                cl::init(true), cl::Hidden);
 static cl::opt<bool> EnableLoadPRE("enable-load-pre", cl::init(true));
@@ -73,6 +74,11 @@ static cl::opt<bool> EnableLoadPRE("enable-load-pre", cl::init(true));
 static cl::opt<uint32_t>
 MaxRecurseDepth("max-recurse-depth", cl::Hidden, cl::init(1000), cl::ZeroOrMore,
                 cl::desc("Max recurse depth (default = 1000)"));
+#else
+static const bool EnablePRE = true;
+static const bool EnableLoadPRE = true;
+static const uint32_t MaxRecurseDepth = 1000;
+#endif // HLSL Change Ends
 
 //===----------------------------------------------------------------------===//
 //                         ValueTable Class

+ 8 - 0
lib/Transforms/Scalar/IndVarSimplify.cpp

@@ -59,6 +59,8 @@ STATISTIC(NumLFTR        , "Number of loop exit tests replaced");
 STATISTIC(NumElimExt     , "Number of IV sign/zero extends eliminated");
 STATISTIC(NumElimIV      , "Number of congruent IVs eliminated");
 
+#if 0 // HLSL Change Starts - option pending
+
 // Trip count verification can be enabled by default under NDEBUG if we
 // implement a strong expression equivalence checker in SCEV. Until then, we
 // use the verify-indvars flag, which may assert in some cases.
@@ -80,6 +82,12 @@ static cl::opt<ReplaceExitVal> ReplaceExitValue(
                clEnumValN(AlwaysRepl, "always",
                           "always replace exit value whenever possible"),
                clEnumValEnd));
+#else
+static const bool VerifyIndvars = false;
+static const bool ReduceLiveIVs = false;
+enum ReplaceExitVal { NeverRepl, OnlyCheapRepl, AlwaysRepl };
+static const ReplaceExitVal ReplaceExitValue = OnlyCheapRepl;
+#endif
 
 namespace {
 struct RewritePhi;

+ 4 - 0
lib/Transforms/Scalar/JumpThreading.cpp

@@ -44,10 +44,14 @@ STATISTIC(NumThreads, "Number of jumps threaded");
 STATISTIC(NumFolds,   "Number of terminators folded");
 STATISTIC(NumDupes,   "Number of branch blocks duplicated to eliminate phi");
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<unsigned>
 BBDuplicateThreshold("jump-threading-threshold",
           cl::desc("Max block size to duplicate for jump threading"),
           cl::init(6), cl::Hidden);
+#else
+static const unsigned BBDuplicateThreshold = 6;
+#endif // HLSL Change Ends
 
 namespace {
   // These are at global scope so static functions can use them too.

+ 4 - 0
lib/Transforms/Scalar/LICM.cpp

@@ -67,9 +67,13 @@ STATISTIC(NumMovedLoads, "Number of load insts hoisted or sunk");
 STATISTIC(NumMovedCalls, "Number of call insts hoisted or sunk");
 STATISTIC(NumPromoted  , "Number of memory locations promoted to registers");
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool>
 DisablePromotion("disable-licm-promotion", cl::Hidden,
                  cl::desc("Disable memory promotion in LICM pass"));
+#else
+static bool DisablePromotion = false;
+#endif // HLSL Change Ends
 
 static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI);
 static bool isNotUsedInLoop(const Instruction &I, const Loop *CurLoop);

+ 5 - 0
lib/Transforms/Scalar/LoopDistribute.cpp

@@ -42,6 +42,7 @@
 
 using namespace llvm;
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool>
     LDistVerify("loop-distribute-verify", cl::Hidden,
                 cl::desc("Turn on DominatorTree and LoopInfo verification "
@@ -53,6 +54,10 @@ static cl::opt<bool> DistributeNonIfConvertible(
     cl::desc("Whether to distribute into a loop that may not be "
              "if-convertible by the loop vectorizer"),
     cl::init(false));
+#else
+static const bool LDistVerify = false;
+static const bool DistributeNonIfConvertible = false;
+#endif
 
 STATISTIC(NumLoopsDistributed, "Number of loops distributed");
 

+ 5 - 0
lib/Transforms/Scalar/LoopRerollPass.cpp

@@ -41,6 +41,7 @@ using namespace llvm;
 
 STATISTIC(NumRerolledLoops, "Number of rerolled loops");
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<unsigned>
 MaxInc("max-reroll-increment", cl::init(2048), cl::Hidden,
   cl::desc("The maximum increment for loop rerolling"));
@@ -50,6 +51,10 @@ NumToleratedFailedMatches("reroll-num-tolerated-failed-matches", cl::init(400),
                           cl::Hidden,
                           cl::desc("The maximum number of failures to tolerate"
                                    " during fuzzy matching. (default: 400)"));
+#else
+static const unsigned MaxInc = 2048;
+static const unsigned NumToleratedFailedMatches = 400;
+#endif // HLSL Change Ends
 
 // This loop re-rolling transformation aims to transform loops like this:
 //

+ 4 - 0
lib/Transforms/Scalar/LoopRotation.cpp

@@ -36,9 +36,13 @@ using namespace llvm;
 
 #define DEBUG_TYPE "loop-rotate"
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<unsigned>
 DefaultRotationThreshold("rotation-max-header-size", cl::init(16), cl::Hidden,
        cl::desc("The default maximum header size for automatic loop rotation"));
+#else
+static const unsigned DefaultRotationThreshold = 16;
+#endif // HLSL Change Ends
 
 STATISTIC(NumRotated, "Number of loops rotated");
 namespace {

+ 20 - 0
lib/Transforms/Scalar/LoopUnrollPass.cpp

@@ -37,6 +37,7 @@ using namespace llvm;
 
 #define DEBUG_TYPE "loop-unroll"
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<unsigned>
     UnrollThreshold("unroll-threshold", cl::init(150), cl::Hidden,
                     cl::desc("The baseline cost threshold for loop unrolling"));
@@ -75,6 +76,25 @@ static cl::opt<unsigned>
 PragmaUnrollThreshold("pragma-unroll-threshold", cl::init(16 * 1024), cl::Hidden,
   cl::desc("Unrolled size limit for loops with an unroll(full) or "
            "unroll_count pragma."));
+#else
+template <typename T>
+struct NullOpt {
+  NullOpt(T val) : _val(val) {}
+  T _val;
+  unsigned getNumOccurrences() const { return 0; }
+  operator T() const {
+    return _val;
+  }
+};
+static const NullOpt<unsigned> UnrollThreshold = 150;
+static const NullOpt<unsigned> UnrollPercentDynamicCostSavedThreshold = 20;
+static const NullOpt<unsigned> UnrollDynamicCostSavingsDiscount = 2000;
+static const NullOpt<unsigned> UnrollMaxIterationsCountToAnalyze = 0;
+static const NullOpt<unsigned> UnrollCount = 0;
+static const NullOpt<bool> UnrollAllowPartial = false;
+static const NullOpt<bool> UnrollRuntime = false;
+static const NullOpt<unsigned> PragmaUnrollThreshold = 16 * 1024;
+#endif // HLSL Change Ends
 
 namespace {
   class LoopUnroll : public LoopPass {

+ 4 - 0
lib/Transforms/Scalar/LoopUnswitch.cpp

@@ -66,9 +66,13 @@ STATISTIC(TotalInsts,  "Total number of instructions analyzed");
 
 // The specific value of 100 here was chosen based only on intuition and a
 // few specific examples.
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<unsigned>
 Threshold("loop-unswitch-threshold", cl::desc("Max loop size to unswitch"),
           cl::init(100), cl::Hidden);
+#else
+static const unsigned Threshold = 100;
+#endif
 
 namespace {
 

+ 11 - 1
lib/Transforms/Scalar/LowerExpectIntrinsic.cpp

@@ -34,12 +34,17 @@ using namespace llvm;
 STATISTIC(ExpectIntrinsicsHandled,
           "Number of 'expect' intrinsic instructions handled");
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<uint32_t>
 LikelyBranchWeight("likely-branch-weight", cl::Hidden, cl::init(64),
                    cl::desc("Weight of the branch likely to be taken (default = 64)"));
 static cl::opt<uint32_t>
 UnlikelyBranchWeight("unlikely-branch-weight", cl::Hidden, cl::init(4),
                    cl::desc("Weight of the branch unlikely to be taken (default = 4)"));
+#else
+static const uint32_t LikelyBranchWeight = 64;
+static const uint32_t UnlikelyBranchWeight = 4;
+#endif
 
 static bool handleSwitchExpect(SwitchInst &SI) {
   CallInst *CI = dyn_cast<CallInst>(SI.getCondition());
@@ -51,13 +56,18 @@ static bool handleSwitchExpect(SwitchInst &SI) {
     return false;
 
   Value *ArgValue = CI->getArgOperand(0);
-  ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+  ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1)); 
   if (!ExpectedValue)
     return false;
 
   SwitchInst::CaseIt Case = SI.findCaseValue(ExpectedValue);
   unsigned n = SI.getNumCases(); // +1 for default case.
+#if 0 // HLSL Change - help the compiler pick the right constructor overload
   SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeight);
+#else
+  SmallVector<uint32_t, 16> Weights;
+  Weights.assign(n + 1, UnlikelyBranchWeight);
+#endif
 
   if (Case == SI.case_default())
     Weights[0] = LikelyBranchWeight;

+ 6 - 0
lib/Transforms/Scalar/SROA.cpp

@@ -77,6 +77,7 @@ STATISTIC(NumLoadsSpeculated, "Number of loads speculated to allow promotion");
 STATISTIC(NumDeleted, "Number of instructions deleted");
 STATISTIC(NumVectorized, "Number of vectorized aggregates");
 
+#if 0 // HLSL Change Starts - option pending
 /// Hidden option to force the pass to not use DomTree and mem2reg, instead
 /// forming SSA values through the SSAUpdater infrastructure.
 static cl::opt<bool> ForceSSAUpdater("force-ssa-updater", cl::init(false),
@@ -91,6 +92,11 @@ static cl::opt<bool> SROARandomShuffleSlices("sroa-random-shuffle-slices",
 /// GEPs.
 static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", cl::init(false),
                                         cl::Hidden);
+#else
+static const bool ForceSSAUpdater = false;
+static const bool SROARandomShuffleSlices = false;
+static const bool SROAStrictInbounds = false;
+#endif // HLSL Change Ends
 
 namespace {
 /// \brief A custom IRBuilder inserter which prefixes all names if they are

+ 5 - 0
lib/Transforms/Scalar/SampleProfile.cpp

@@ -52,6 +52,7 @@ using namespace sampleprof;
 
 #define DEBUG_TYPE "sample-profile"
 
+#if 0 // HLSL Change Start
 // Command line option to specify the file to read samples from. This is
 // mainly used for debugging.
 static cl::opt<std::string> SampleProfileFile(
@@ -61,6 +62,10 @@ static cl::opt<unsigned> SampleProfileMaxPropagateIterations(
     "sample-profile-max-propagate-iterations", cl::init(100),
     cl::desc("Maximum number of iterations to go through when propagating "
              "sample block/edge weights through the CFG."));
+#else
+static const char SampleProfileFile[] = "";
+static const unsigned SampleProfileMaxPropagateIterations = 100;
+#endif // HLSL Change Ends
 
 namespace {
 typedef DenseMap<BasicBlock *, unsigned> BlockWeightMap;

+ 4 - 0
lib/Transforms/Scalar/SimplifyCFGPass.cpp

@@ -42,9 +42,13 @@ using namespace llvm;
 
 #define DEBUG_TYPE "simplifycfg"
 
+#if 0 // HLSL Change Starts
 static cl::opt<unsigned>
 UserBonusInstThreshold("bonus-inst-threshold", cl::Hidden, cl::init(1),
    cl::desc("Control the number of bonus instructions (default = 1)"));
+#else
+unsigned UserBonusInstThreshold = 1;
+#endif
 
 STATISTIC(NumSimpl, "Number of blocks simplified");
 

+ 5 - 0
lib/Transforms/Utils/InlineFunction.cpp

@@ -43,6 +43,7 @@
 #include <algorithm>
 using namespace llvm;
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool>
 EnableNoAliasConversion("enable-noalias-to-md-conversion", cl::init(true),
   cl::Hidden,
@@ -52,6 +53,10 @@ static cl::opt<bool>
 PreserveAlignmentAssumptions("preserve-alignment-assumptions-during-inlining",
   cl::init(true), cl::Hidden,
   cl::desc("Convert align attributes to assumptions during inlining."));
+#else
+static const bool EnableNoAliasConversion = true;
+static const bool PreserveAlignmentAssumptions = true;
+#endif // HLSL Change Ends
 
 bool llvm::InlineFunction(CallInst *CI, InlineFunctionInfo &IFI,
                           bool InsertLifetime) {

+ 8 - 1
lib/Transforms/Utils/SimplifyCFG.cpp

@@ -57,6 +57,7 @@ using namespace PatternMatch;
 // a select, so the "clamp" idiom (of a min followed by a max) will be caught.
 // To catch this, we need to fold a compare and a select, hence '2' being the
 // minimum reasonable default.
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<unsigned>
 PHINodeFoldingThreshold("phi-node-folding-threshold", cl::Hidden, cl::init(2),
    cl::desc("Control the amount of phi node folding to perform (default = 2)"));
@@ -67,11 +68,17 @@ DupRet("simplifycfg-dup-ret", cl::Hidden, cl::init(false),
 
 static cl::opt<bool>
 SinkCommon("simplifycfg-sink-common", cl::Hidden, cl::init(true),
-       cl::desc("Sink common instructions down to the end block"));
+           cl::desc("Sink common instructions down to the end block"));
 
 static cl::opt<bool> HoistCondStores(
     "simplifycfg-hoist-cond-stores", cl::Hidden, cl::init(true),
     cl::desc("Hoist conditional stores if an unconditional store precedes"));
+#else
+static const unsigned PHINodeFoldingThreshold = 2;
+static const bool DupRet = false;
+static const bool SinkCommon = true;
+static const bool HoistCondStores = true;
+#endif // HLSL Change Ends
 
 STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps");
 STATISTIC(NumLinearMaps, "Number of switch instructions turned into linear mapping");

+ 6 - 2
lib/Transforms/Utils/SimplifyLibCalls.cpp

@@ -36,6 +36,7 @@
 using namespace llvm;
 using namespace PatternMatch;
 
+#if 0 // HLSL Change Starts - option pending
 static cl::opt<bool>
     ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden,
                    cl::desc("Treat error-reporting calls as cold"));
@@ -45,7 +46,10 @@ static cl::opt<bool>
                          cl::init(false),
                          cl::desc("Enable unsafe double to float "
                                   "shrinking for math lib calls"));
-
+#else
+static const bool ColdErrorCalls = true;
+static const bool EnableUnsafeFPShrink = false;
+#endif // HLSL Change Ends
 
 //===----------------------------------------------------------------------===//
 // Helper Functions
@@ -1966,7 +1970,7 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
   bool isCallingConvC = CI->getCallingConv() == llvm::CallingConv::C;
 
   // Command-line parameter overrides function attribute.
-  if (EnableUnsafeFPShrink.getNumOccurrences() > 0)
+  if (false) // HLSL Change - EnableUnsafeFPShrink.getNumOccurrences() > 0)
     UnsafeFPShrink = EnableUnsafeFPShrink;
   else if (Callee->hasFnAttribute("unsafe-fp-math")) {
     // FIXME: This is the same problem as described in optimizeSqrt().

+ 3 - 1
lib/Transforms/Utils/SymbolRewriter.cpp

@@ -75,9 +75,11 @@
 using namespace llvm;
 using namespace SymbolRewriter;
 
+#if 0 // HLSL Change Starts - option pending
 static cl::list<std::string> RewriteMapFiles("rewrite-map-file",
                                              cl::desc("Symbol Rewrite Map"),
                                              cl::value_desc("filename"));
+#endif // HLSL Change Ends
 
 static void rewriteComdat(Module &M, GlobalObject *GO,
                           const std::string &Source,
@@ -532,7 +534,7 @@ bool RewriteSymbols::runOnModule(Module &M) {
 }
 
 void RewriteSymbols::loadAndParseMapFiles() {
-  const std::vector<std::string> MapFiles(RewriteMapFiles);
+  const std::vector<std::string> MapFiles; // HLSL Change - do not init from a global RewriteMapFiles
   SymbolRewriter::RewriteMapParser parser;
 
   for (const auto &MapFile : MapFiles)

+ 7 - 2
tools/clang/lib/AST/DeclarationName.cpp

@@ -346,8 +346,9 @@ void DeclarationName::dump() const {
 }
 
 DeclarationNameTable::DeclarationNameTable(const ASTContext &C) : Ctx(C) {
-  CXXSpecialNamesImpl = new llvm::FoldingSet<CXXSpecialName>;
-  CXXLiteralOperatorNames = new llvm::FoldingSet<CXXLiteralOperatorIdName>;
+  // HLSL Change Starts - use std::unique_ptr to avoid leaks
+  std::unique_ptr<llvm::FoldingSet<CXXSpecialName> > CXXSpecialNamesImplPtr(new llvm::FoldingSet<CXXSpecialName>());
+  std::unique_ptr<llvm::FoldingSet<CXXLiteralOperatorIdName> > CXXLiteralOperatorNamesPtr(new llvm::FoldingSet<CXXLiteralOperatorIdName>());
 
   // Initialize the overloaded operator names.
   CXXOperatorNames = new (Ctx) CXXOperatorIdName[NUM_OVERLOADED_OPERATORS];
@@ -356,6 +357,10 @@ DeclarationNameTable::DeclarationNameTable(const ASTContext &C) : Ctx(C) {
       = Op + DeclarationNameExtra::CXXConversionFunction;
     CXXOperatorNames[Op].FETokenInfo = nullptr;
   }
+
+  CXXSpecialNamesImpl = CXXSpecialNamesImplPtr.release();
+  CXXLiteralOperatorNames = CXXLiteralOperatorNamesPtr.release();
+  // HLSL Change Ends - use std::unique_ptr to avoid leaks
 }
 
 DeclarationNameTable::~DeclarationNameTable() {

+ 1 - 0
tools/clang/lib/Basic/SourceManager.cpp

@@ -112,6 +112,7 @@ llvm::MemoryBuffer *ContentCache::getBuffer(DiagnosticsEngine &Diag,
     StringRef FillStr("<<<MISSING SOURCE FILE>>>\n");
     Buffer.setPointer(MemoryBuffer::getNewUninitMemBuffer(
                           ContentsEntry->getSize(), "<invalid>").release());
+    if (Buffer.getPointer() == nullptr) throw std::bad_alloc(); // HLSL Change
     char *Ptr = const_cast<char*>(Buffer.getPointer()->getBufferStart());
     for (unsigned i = 0, e = ContentsEntry->getSize(); i != e; ++i)
       Ptr[i] = FillStr[i % FillStr.size()];

+ 14 - 0
tools/clang/lib/Basic/VirtualFileSystem.cpp

@@ -176,11 +176,25 @@ RealFileSystem::openFileForRead(const Twine &Name) {
   return std::move(Result);
 }
 
+#if 0 // HLSL Change Starts
+
 IntrusiveRefCntPtr<FileSystem> vfs::getRealFileSystem() {
   static IntrusiveRefCntPtr<FileSystem> FS = new RealFileSystem();
   return FS;
 }
 
+#else 
+
+static RealFileSystem g_RealFileSystem;
+
+IntrusiveRefCntPtr<FileSystem> vfs::getRealFileSystem() {
+  g_RealFileSystem.Retain(); // never let go - TODO: guard against refcount wraparound
+  IntrusiveRefCntPtr<FileSystem> Result(&g_RealFileSystem);
+  return Result;
+}
+
+#endif // HLSL Change Ends
+
 namespace {
 class RealFSDirIter : public clang::vfs::detail::DirIterImpl {
   std::string Path;

+ 8 - 0
tools/clang/lib/CodeGen/CodeGenAction.cpp

@@ -74,6 +74,14 @@ namespace clang {
       llvm::TimePassesIsEnabled = TimePasses;
     }
 
+    // HLSL Change Starts - avoid double free
+    ~BackendConsumer() {
+      if (TheModule.get() && Gen.get()) {
+        Gen->ReleaseModule();
+      }
+    }
+    // HLSL Change Ends - avoid double free
+
     std::unique_ptr<llvm::Module> takeModule() { return std::move(TheModule); }
     llvm::Module *takeLinkModule() { return LinkModule.release(); }
 

+ 21 - 3
tools/clang/lib/CodeGen/CodeGenModule.cpp

@@ -125,14 +125,22 @@ CodeGenModule::CodeGenModule(ASTContext &C, const HeaderSearchOptions &HSO,
     createOpenMPRuntime();
   if (LangOpts.CUDA)
     createCUDARuntime();
-  if (LangOpts.HLSL)         // HLSL Change 
-    createHLSLRuntime();     // HLSL Change
+  // HLSL Change Starts
+  std::unique_ptr<CGHLSLRuntime> RuntimePtr;
+  std::unique_ptr<CodeGenTBAA> TBAAPtr;
+  std::unique_ptr<CGDebugInfo> DebugInfoPtr;
+  if (LangOpts.HLSL) {
+    createHLSLRuntime();
+    RuntimePtr.reset(HLSLRuntime);
+  }
+  // HLSL Change Ends
 
   // Enable TBAA unless it's suppressed. ThreadSanitizer needs TBAA even at O0.
   if (LangOpts.Sanitize.has(SanitizerKind::Thread) ||
       (!CodeGenOpts.RelaxedAliasing && CodeGenOpts.OptimizationLevel > 0))
     TBAA = new CodeGenTBAA(Context, VMContext, CodeGenOpts, getLangOpts(),
                            getCXXABI().getMangleContext());
+  TBAAPtr.reset(TBAA); // HLSL Change
 
   // If debug info or coverage generation is enabled, create the CGDebugInfo
   // object.
@@ -140,12 +148,15 @@ CodeGenModule::CodeGenModule(ASTContext &C, const HeaderSearchOptions &HSO,
       CodeGenOpts.EmitGcovArcs ||
       CodeGenOpts.EmitGcovNotes)
     DebugInfo = new CGDebugInfo(*this);
+  DebugInfoPtr.reset(DebugInfo); // HLSL Change
 
   Block.GlobalUniqueCount = 0;
 
+#if 0 // HLSL Change Starts - no ARC support
   if (C.getLangOpts().ObjCAutoRefCount)
     ARCData = new ARCEntrypoints();
   RRData = new RREntrypoints();
+#endif // HLSL Change Ends - no ARC support
 
   if (!CodeGenOpts.InstrProfileInput.empty()) {
     auto ReaderOrErr =
@@ -163,6 +174,12 @@ CodeGenModule::CodeGenModule(ASTContext &C, const HeaderSearchOptions &HSO,
   // CoverageMappingModuleGen object.
   if (CodeGenOpts.CoverageMapping)
     CoverageMapping.reset(new CoverageMappingModuleGen(*this, *CoverageInfo));
+
+  // HLSL Change Starts - release acquired pointers
+  RuntimePtr.release();
+  TBAAPtr.release();
+  DebugInfoPtr.release();
+  // HLSL Change Ends
 }
 
 CodeGenModule::~CodeGenModule() {
@@ -171,7 +188,7 @@ CodeGenModule::~CodeGenModule() {
   delete OpenMPRuntime;
   delete CUDARuntime;
   delete HLSLRuntime;  // HLSL Change
-  delete TheTargetCodeGenInfo;
+  TheTargetCodeGenInfo.reset(nullptr); // HLSL Change
   delete TBAA;
   delete DebugInfo;
   delete ARCData;
@@ -1626,6 +1643,7 @@ CodeGenModule::GetOrCreateLLVMFunction(StringRef MangledName,
     IsIncompleteFunction = true;
   }
   
+  // HLSL Change: unique_ptr for F
   llvm::Function *F = llvm::Function::Create(FTy,
                                              llvm::Function::ExternalLinkage,
                                              MangledName, &getModule());

+ 1 - 1
tools/clang/lib/CodeGen/CodeGenModule.h

@@ -293,7 +293,7 @@ private:
 
   CodeGenTBAA *TBAA;
   
-  mutable const TargetCodeGenInfo *TheTargetCodeGenInfo;
+  mutable std::unique_ptr<TargetCodeGenInfo> TheTargetCodeGenInfo; // HLSL Change - unique_ptr rather than const *
   
   // This should not be moved earlier, since its initialization depends on some
   // of the previous reference members being already initialized and also checks

+ 6 - 3
tools/clang/lib/CodeGen/TargetInfo.cpp

@@ -7204,8 +7204,8 @@ const TargetCodeGenInfo &CodeGenModule::getTargetCodeGenInfo() {
   const llvm::Triple &Triple = getTarget().getTriple();
   switch (Triple.getArch()) {
   default:
-    return *(TheTargetCodeGenInfo = new DefaultTargetCodeGenInfo(Types));
-
+    TheTargetCodeGenInfo.reset(new DefaultTargetCodeGenInfo(Types)); break; // HLSL Change - reset
+#if 0 // HLSL Change Starts
   case llvm::Triple::le32:
     return *(TheTargetCodeGenInfo = new PNaClTargetCodeGenInfo(Types));
   case llvm::Triple::mips:
@@ -7334,10 +7334,13 @@ const TargetCodeGenInfo &CodeGenModule::getTargetCodeGenInfo() {
     return *(TheTargetCodeGenInfo = new SparcV9TargetCodeGenInfo(Types));
   case llvm::Triple::xcore:
     return *(TheTargetCodeGenInfo = new XCoreTargetCodeGenInfo(Types));
+#endif // HLSL Change Ends
   // HLSL Change Begins
   case llvm::Triple::dxil:
   case llvm::Triple::dxil64:
-    return *(TheTargetCodeGenInfo = new MSDXILTargetCodeGenInfo(Types));
+    TheTargetCodeGenInfo.reset(new MSDXILTargetCodeGenInfo(Types));
+    break;
   // HLSL Change Ends
   }
+  return *(TheTargetCodeGenInfo.get());
 }

+ 4 - 3
tools/clang/lib/Frontend/ASTUnit.cpp

@@ -196,11 +196,12 @@ void ASTUnit::clearFileLevelDecls() {
 }
 
 void ASTUnit::CleanTemporaryFiles() {
-  getOnDiskData(this).CleanTemporaryFiles();
+  // getOnDiskData(this).CleanTemporaryFiles(); // HLSL Change - no temporary files generated
 }
 
 void ASTUnit::addTemporaryFile(StringRef TempFile) {
-  getOnDiskData(this).TemporaryFiles.push_back(TempFile);
+  // getOnDiskData(this).TemporaryFiles.push_back(TempFile); // HLSL Change - no temporary files generated
+  assert("caller attempted to create a temporary file");
 }
 
 /// \brief After failing to build a precompiled preamble (due to
@@ -243,7 +244,7 @@ ASTUnit::~ASTUnit() {
   clearFileLevelDecls();
 
   // Clean up the temporary files and the preamble file.
-  removeOnDiskEntry(this);
+  // removeOnDiskEntry(this); // HLSL Change - no temporary/preamble files generated.
 
   // Free the buffers associated with remapped files. We are required to
   // perform this operation here because we explicitly request that the

+ 16 - 3
tools/clang/lib/Frontend/CompilerInstance.cpp

@@ -64,6 +64,8 @@ CompilerInstance::CompilerInstance(
 #endif // HLSL Change Ends - no support for modules
 
 CompilerInstance::~CompilerInstance() {
+  // TODO: harden output file cleanup so there are no additional allocations/exceptions
+  clearOutputFiles(/* EraseFiles */ false); // HLSL Change - might happen when destroying under exception
   assert(OutputFiles.empty() && "Still output files in flight?");
 }
 
@@ -310,14 +312,16 @@ void CompilerInstance::createPreprocessor(TranslationUnitKind TUKind) {
     PTHMgr = PTHManager::Create(PPOpts.TokenCache, getDiagnostics());
 
   // Create the Preprocessor.
-  HeaderSearch *HeaderInfo = new HeaderSearch(&getHeaderSearchOpts(),
+  std::unique_ptr<HeaderSearch> HeaderInfo ( // HLSL Change - make unique_ptr and free
+                             new HeaderSearch(&getHeaderSearchOpts(),
                                               getSourceManager(),
                                               getDiagnostics(),
                                               getLangOpts(),
-                                              &getTarget());
+                                              &getTarget()));
   PP = new Preprocessor(&getPreprocessorOpts(), getDiagnostics(), getLangOpts(),
-                        getSourceManager(), *HeaderInfo, *this, PTHMgr,
+                        getSourceManager(), *HeaderInfo.get(), *this, PTHMgr,
                         /*OwnsHeaderSearch=*/true, TUKind);
+  HeaderInfo.release(); // HLSL Change - Preprocessor has ownership at this point
   PP->Initialize(getTarget());
 
   // Note that this is different then passing PTHMgr to Preprocessor's ctor.
@@ -557,8 +561,16 @@ void CompilerInstance::addOutputFile(OutputFile &&OutFile) {
 }
 
 void CompilerInstance::clearOutputFiles(bool EraseFiles) {
+  bool errorsFound = false; // HLSL Change - track this, but try to clean up everything anyway
   for (OutputFile &OF : OutputFiles) {
     // Manually close the stream before we rename it.
+    // HLSL Change Starts - call explicitly to have it throw on error during regular flow (rather than .dtor)
+    if (OF.OS.get()) {
+      OF.OS->close();
+      errorsFound = errorsFound || OF.OS->has_error();
+      OF.OS->clear_error();
+    }
+    // HLSL Change Ends
     OF.OS.reset();
 
     if (!OF.TempFilename.empty()) {
@@ -584,6 +596,7 @@ void CompilerInstance::clearOutputFiles(bool EraseFiles) {
   }
   OutputFiles.clear();
   NonSeekStream.reset();
+  if (errorsFound) throw std::exception("errors when processing output"); // HLSL Change
 }
 
 raw_pwrite_stream *

+ 2 - 2
tools/clang/lib/Lex/PPLexerChange.cpp

@@ -109,11 +109,11 @@ bool Preprocessor::EnterSourceFile(FileID FID, const DirectoryLookup *CurDir,
 ///  and start lexing tokens from it instead of the current buffer.
 void Preprocessor::EnterSourceFileWithLexer(Lexer *TheLexer,
                                             const DirectoryLookup *CurDir) {
-
+  std::unique_ptr<Lexer> LexerGuard(TheLexer); // HLSL Change - guard
   // Add the current lexer to the include stack.
   if (CurPPLexer || CurTokenLexer)
     PushIncludeMacroStack();
-
+  LexerGuard.release(); // HLSL Change
   CurLexer.reset(TheLexer);
   CurPPLexer = TheLexer;
   CurDirLookup = CurDir;

+ 1 - 0
tools/clang/lib/Lex/Preprocessor.cpp

@@ -513,6 +513,7 @@ void Preprocessor::EnterMainSourceFile() {
   // Preprocess Predefines to populate the initial preprocessor state.
   std::unique_ptr<llvm::MemoryBuffer> SB =
     llvm::MemoryBuffer::getMemBufferCopy(Predefines, "<built-in>");
+  if (SB.get() == nullptr) throw std::bad_alloc(); // HLSL Change
   assert(SB && "Cannot create predefined source buffer");
   FileID FID = SourceMgr.createFileID(std::move(SB));
   assert(!FID.isInvalid() && "Could not create FileID for predefines?");

+ 0 - 3
tools/clang/test/HLSL/pix/removeDiscards.hlsl

@@ -1,10 +1,7 @@
 // RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-remove-discards | %FileCheck %s
 
 // Check that the discard within the if/then was removed:
-//     CHECK: if.then:                                          ; preds = %entry
 // CHECK-NOT:   call void @dx.op.discard(i32 82, i1 true)
-//     CHECK:   br label %if.end
-//     CHECK: if.end:
 
 struct RTOut
 {

+ 84 - 9
tools/clang/tools/dxc/dxc.cpp

@@ -62,6 +62,59 @@
 #include <algorithm>
 #include <unordered_map>
 
+struct NoSerializeHeapMalloc : public IMalloc {
+private:
+  HANDLE m_Handle;
+public:
+  void SetHandle(HANDLE Handle) { m_Handle = Handle; }
+  ULONG STDMETHODCALLTYPE AddRef() {
+    return 1;
+  }
+  ULONG STDMETHODCALLTYPE Release() {
+    return 1;
+  }
+  STDMETHODIMP QueryInterface(REFIID iid, void** ppvObject) {
+    return DoBasicQueryInterface<IMalloc>(this, iid, ppvObject);
+  }
+  virtual void *STDMETHODCALLTYPE Alloc(
+    _In_  SIZE_T cb) {
+    return HeapAlloc(m_Handle, 0, cb);
+  }
+
+  virtual void *STDMETHODCALLTYPE Realloc(
+    _In_opt_  void *pv,
+    _In_  SIZE_T cb)
+  {
+    return HeapReAlloc(m_Handle, 0, pv, cb);
+  }
+
+  virtual void STDMETHODCALLTYPE Free(
+    _In_opt_  void *pv)
+  {
+    HeapFree(m_Handle, 0, pv);
+  }
+
+
+  virtual SIZE_T STDMETHODCALLTYPE GetSize(
+    /* [annotation][in] */
+    _In_opt_ _Post_writable_byte_size_(return)  void *pv)
+  {
+    return HeapSize(m_Handle, 0, pv);
+  }
+
+  virtual int STDMETHODCALLTYPE DidAlloc(
+    /* [annotation][in] */
+    _In_opt_  void *pv)
+  {
+    return -1; // don't know
+  }
+
+
+  virtual void STDMETHODCALLTYPE HeapMinimize(void)
+  {
+  }
+};
+
 inline bool wcseq(LPCWSTR a, LPCWSTR b) {
   return (a == nullptr && b == nullptr) || (a != nullptr && b != nullptr && wcscmp(a, b) == 0);
 }
@@ -76,6 +129,8 @@ class DxcContext {
 private:
   DxcOpts &m_Opts;
   DxcDllSupport &m_dxcSupport;
+  NoSerializeHeapMalloc m_Malloc;
+  HANDLE m_MallocHeap;
 
   int ActOnBlob(IDxcBlob *pBlob);
   int ActOnBlob(IDxcBlob *pBlob, IDxcBlob *pDebugBlob, LPCWSTR pDebugBlobName);
@@ -91,9 +146,25 @@ private:
   void ExtractRootSignature(IDxcBlob *pBlob, IDxcBlob **ppResult);
   int VerifyRootSignature();
 
+  template <typename TInterface>
+  HRESULT CreateInstance(REFCLSID clsid, _Outptr_ TInterface** pResult) {
+    if (m_dxcSupport.HasCreateWithMalloc())
+      return m_dxcSupport.CreateInstance2(&m_Malloc, clsid, pResult);
+    else
+      return m_dxcSupport.CreateInstance(clsid, pResult);
+  }
+
 public:
   DxcContext(DxcOpts &Opts, DxcDllSupport &dxcSupport)
-      : m_Opts(Opts), m_dxcSupport(dxcSupport) {}
+      : m_Opts(Opts), m_dxcSupport(dxcSupport), m_MallocHeap(nullptr) {
+    if (m_dxcSupport.HasCreateWithMalloc()) {
+      m_MallocHeap = HeapCreate(HEAP_NO_SERIALIZE, 1024 * 1024 * 2, 0);
+      if (m_MallocHeap == NULL)
+        IFT_Data(HRESULT_FROM_WIN32(GetLastError()), L"unable to create custom heap");
+      m_Malloc.SetHandle(m_MallocHeap);
+      // We never free the heap because it's tied to the dxc process lifetime
+    }
+  }
 
   int  Compile();
   void Recompile(IDxcBlob *pSource, IDxcLibrary *pLibrary, IDxcCompiler *pCompiler, std::vector<LPCWSTR> &args, IDxcOperationResult **pCompileResult);
@@ -211,7 +282,7 @@ int DxcContext::ActOnBlob(IDxcBlob *pBlob, IDxcBlob *pDebugBlob, LPCWSTR pDebugB
       IFT(pLibrary->CreateBlobWithEncodingOnHeapCopy((LPBYTE)&Message[0], Message.size(), CP_ACP, &pDisassembleResult));
   } else {
       CComPtr<IDxcCompiler> pCompiler;
-      IFT(m_dxcSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
+      IFT(CreateInstance(CLSID_DxcCompiler, &pCompiler));
       IFT(pCompiler->Disassemble(pBlob, &pDisassembleResult));
   }
   
@@ -240,7 +311,7 @@ void DxcContext::UpdatePart(IDxcBlob *pSource, IDxcBlob **ppResult) {
 
   CComPtr<IDxcContainerBuilder> pContainerBuilder;
   CComPtr<IDxcBlob> pResult;
-  IFT(m_dxcSupport.CreateInstance(CLSID_DxcContainerBuilder, &pContainerBuilder));
+  IFT(CreateInstance(CLSID_DxcContainerBuilder, &pContainerBuilder));
   
   // Load original container and update blob for each given option
   IFT(pContainerBuilder->Load(pSource));
@@ -400,7 +471,7 @@ int DxcContext::VerifyRootSignature() {
   // Since dxil container builder will verify on its behalf. 
   // This does unnecessary memory allocation. We can improve this later. 
   CComPtr<IDxcContainerBuilder> pContainerBuilder;
-  IFT(m_dxcSupport.CreateInstance(CLSID_DxcContainerBuilder, &pContainerBuilder));
+  IFT(CreateInstance(CLSID_DxcContainerBuilder, &pContainerBuilder));
   IFT(pContainerBuilder->Load(pSource));
   // Try removing root signature if it already exists
   pContainerBuilder->RemovePart(hlsl::DxilFourCC::DFCC_RootSignature);
@@ -634,8 +705,8 @@ int DxcContext::Compile() {
       args.push_back(L"-ast-dump");
 
     CComPtr<IDxcLibrary> pLibrary;
-    IFT(m_dxcSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary));
-    IFT(m_dxcSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
+    IFT(CreateInstance(CLSID_DxcLibrary, &pLibrary));
+    IFT(CreateInstance(CLSID_DxcCompiler, &pCompiler));
     ReadFileIntoBlob(m_dxcSupport, StringRefUtf16(m_Opts.InputFile), &pSource);
     IFTARG(pSource->GetBufferSize() >= 4);
 
@@ -716,11 +787,11 @@ void DxcContext::Preprocess() {
 
   CComPtr<IDxcLibrary> pLibrary;
   CComPtr<IDxcIncludeHandler> pIncludeHandler;
-  IFT(m_dxcSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary));
+  IFT(CreateInstance(CLSID_DxcLibrary, &pLibrary));
   IFT(pLibrary->CreateIncludeHandler(&pIncludeHandler));
 
   ReadFileIntoBlob(m_dxcSupport, StringRefUtf16(m_Opts.InputFile), &pSource);
-  IFT(m_dxcSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
+  IFT(CreateInstance(CLSID_DxcCompiler, &pCompiler));
   IFT(pCompiler->Preprocess(pSource, StringRefUtf16(m_Opts.InputFile), args.data(), args.size(), m_Opts.Defines.data(), m_Opts.Defines.size(), pIncludeHandler, &pPreprocessResult));
   WriteOperationErrorsToConsole(pPreprocessResult, m_Opts.OutputWarnings);
 
@@ -826,7 +897,7 @@ HRESULT DxcContext::GetDxcDiaTable(IDxcLibrary *pLibrary, IDxcBlob *pTargetBlob,
   CComPtr<IStream> pSourceStream;
   CComPtr<IDiaSession> pSession;
   CComPtr<IDiaEnumTables> pEnumTables;
-  IFT(m_dxcSupport.CreateInstance(CLSID_DxcDiaDataSource, &pDataSource));
+  IFT(CreateInstance(CLSID_DxcDiaDataSource, &pDataSource));
   IFT(pLibrary->CreateStreamFromBlobReadOnly(pTargetBlob, &pSourceStream));
   IFT(pDataSource->loadDataFromIStream(pSourceStream));
   IFT(pDataSource->openSession(&pSession));
@@ -853,9 +924,13 @@ HRESULT DxcContext::GetDxcDiaTable(IDxcLibrary *pLibrary, IDxcBlob *pTargetBlob,
 int __cdecl wmain(int argc, const wchar_t **argv_) {
   const char *pStage = "Operation";
   int retVal = 0;
+  if (FAILED(DxcInitThreadMalloc())) return 1;
+  DxcSetThreadMallocOrDefault(nullptr);
   try {
     pStage = "Argument processing";
 
+    if (initHlslOptTable()) throw std::bad_alloc();
+
     // Parse command line options.
     const OptTable *optionTable = getHlslOptTable();
     MainArgs argStrings(argc, argv_);

+ 64 - 17
tools/clang/tools/dxcompiler/DXCompiler.cpp

@@ -11,47 +11,94 @@
 
 #include "llvm/Support/ManagedStatic.h"
 #include "llvm/Support/FileSystem.h"
-
+#include "dxc/Support/Global.h"
 #include "dxc/Support/WinIncludes.h"
+#include "dxc/Support/HLSLOptions.h"
 #include "dxcetw.h"
 #include "dxillib.h"
 
 namespace hlsl { HRESULT SetupRegistryPassForHLSL(); }
 
+// C++ exception specification ignored except to indicate a function is not __declspec(nothrow)
+#pragma warning( disable : 4290 )
+
+// operator new and friends.
+void *operator new(std::size_t size) throw(std::bad_alloc) {
+  void * ptr = DxcGetThreadMallocNoRef()->Alloc(size);
+  if (ptr == nullptr)
+    throw std::bad_alloc();
+  return ptr;
+}
+void *operator new(std::size_t size,
+  const std::nothrow_t &nothrow_value) throw() {
+  return DxcGetThreadMallocNoRef()->Alloc(size);
+}
+void operator delete (void* ptr) throw() {
+  DxcGetThreadMallocNoRef()->Free(ptr);
+}
+void operator delete (void* ptr, const std::nothrow_t& nothrow_constant) throw() {
+  DxcGetThreadMallocNoRef()->Free(ptr);
+}
+
+static HRESULT InitMaybeFail() throw() {
+  HRESULT hr;
+  bool fsSetup = false, memSetup = false;
+  IFC(DxcInitThreadMalloc());
+  DxcSetThreadMallocOrDefault(nullptr);
+  memSetup = true;
+  if (::llvm::sys::fs::SetupPerThreadFileSystem()) {
+    hr = E_FAIL;
+    goto Cleanup;
+  }
+  fsSetup = true;
+  IFC(hlsl::SetupRegistryPassForHLSL());
+  IFC(DxilLibInitialize());
+  if (hlsl::options::initHlslOptTable()) {
+    hr = E_FAIL;
+    goto Cleanup;
+  }
+Cleanup:
+  if (FAILED(hr)) {
+    if (fsSetup) {
+      ::llvm::sys::fs::CleanupPerThreadFileSystem();
+    }
+    if (memSetup) {
+      DxcClearThreadMalloc();
+      DxcCleanupThreadMalloc();
+    }
+  }
+  else {
+    DxcClearThreadMalloc();
+  }
+  return hr;
+}
+
 BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD Reason, LPVOID reserved) {
   BOOL result = TRUE;
   if (Reason == DLL_PROCESS_ATTACH) {
     EventRegisterMicrosoft_Windows_DXCompiler_API();
     DxcEtw_DXCompilerInitialization_Start();
     DisableThreadLibraryCalls(hinstDLL);
-    HRESULT hr = S_OK;
-    if (::llvm::sys::fs::SetupPerThreadFileSystem()) {
-      hr = E_FAIL;
-    }
-    else {
-      hr = hlsl::SetupRegistryPassForHLSL();
-      if (SUCCEEDED(hr)) {
-        DxilLibInitialize();
-      }
-      else {
-        ::llvm::sys::fs::CleanupPerThreadFileSystem();
-      }
-    }
+    HRESULT hr = InitMaybeFail();
     DxcEtw_DXCompilerInitialization_Stop(hr);
     result = SUCCEEDED(hr) ? TRUE : FALSE;
   } else if (Reason == DLL_PROCESS_DETACH) {
     DxcEtw_DXCompilerShutdown_Start();
+    DxcSetThreadMallocOrDefault(nullptr);
+    ::hlsl::options::cleanupHlslOptTable();
     ::llvm::sys::fs::CleanupPerThreadFileSystem();
     ::llvm::llvm_shutdown();
-    DxcEtw_DXCompilerShutdown_Stop(S_OK);
-    EventUnregisterMicrosoft_Windows_DXCompiler_API();
     if (reserved == NULL) { // FreeLibrary has been called or the DLL load failed
       DxilLibCleanup(DxilLibCleanUpType::UnloadLibrary);
     }
     else { // Process termination. We should not call FreeLibrary()
       DxilLibCleanup(DxilLibCleanUpType::ProcessTermination);
     }
-  } 
+    DxcClearThreadMalloc();
+    DxcCleanupThreadMalloc();
+    DxcEtw_DXCompilerShutdown_Stop(S_OK);
+    EventUnregisterMicrosoft_Windows_DXCompiler_API();
+  }
 
   return result;
 }

+ 1 - 0
tools/clang/tools/dxcompiler/DXCompiler.def

@@ -2,3 +2,4 @@ LIBRARY dxcompiler
 
 EXPORTS
     DxcCreateInstance
+    DxcCreateInstance2

+ 36 - 16
tools/clang/tools/dxcompiler/dxcapi.cpp

@@ -15,6 +15,7 @@
 
 #include "dxc/dxcisense.h"
 #include "dxc/dxctools.h"
+#include "dxc/Support/Global.h"
 #include "dxcetw.h"
 #include "dxillib.h"
 #include <memory>
@@ -46,25 +47,11 @@ HRESULT CreateDxcContainerReflection(_In_ REFIID riid, _Out_ LPVOID *ppv) {
   }
 }
 
-/// <summary>
-/// Creates a single uninitialized object of the class associated with a specified CLSID.
-/// </summary>
-/// <param name="rclsid">The CLSID associated with the data and code that will be used to create the object.</param>
-/// <param name="riid">A reference to the identifier of the interface to be used to communicate with the object.</param>
-/// <param name="ppv">Address of pointer variable that receives the interface pointer requested in riid. Upon successful return, *ppv contains the requested interface pointer. Upon failure, *ppv contains NULL.</param>
-/// <remarks>
-/// While this function is similar to CoCreateInstance, there is no COM involvement.  
-/// </remarks>
-DXC_API_IMPORT HRESULT __stdcall
-DxcCreateInstance(_In_ REFCLSID   rclsid,
+static HRESULT ThreadMallocDxcCreateInstance(
+  _In_ REFCLSID   rclsid,
                   _In_ REFIID     riid,
                   _Out_ LPVOID   *ppv) {
-  if (ppv == nullptr) {
-    return E_POINTER;
-  }
-
   HRESULT hr = S_OK;
-  DxcEtw_DXCompilerCreateInstance_Start();
   *ppv = nullptr;
   if (IsEqualCLSID(rclsid, CLSID_DxcIntelliSense)) {
     hr = CreateDxcIntelliSense(riid, ppv);
@@ -106,7 +93,40 @@ DxcCreateInstance(_In_ REFCLSID   rclsid,
   else {
     hr = REGDB_E_CLASSNOTREG;
   }
+  return hr;
+}
+
+DXC_API_IMPORT HRESULT __stdcall
+DxcCreateInstance(
+  _In_ REFCLSID   rclsid,
+  _In_ REFIID     riid,
+  _Out_ LPVOID   *ppv) {
+  if (ppv == nullptr) {
+    return E_POINTER;
+  }
 
+  HRESULT hr = S_OK;
+  DxcEtw_DXCompilerCreateInstance_Start();
+  DxcThreadMalloc TM(nullptr);
+  hr = ThreadMallocDxcCreateInstance(rclsid, riid, ppv);
+  DxcEtw_DXCompilerCreateInstance_Stop(hr);
+  return hr;
+}
+
+DXC_API_IMPORT HRESULT __stdcall
+DxcCreateInstance2(
+  _In_ IMalloc    *pMalloc,
+  _In_ REFCLSID   rclsid,
+  _In_ REFIID     riid,
+  _Out_ LPVOID   *ppv) {
+  if (ppv == nullptr) {
+    return E_POINTER;
+  }
+
+  HRESULT hr = S_OK;
+  DxcEtw_DXCompilerCreateInstance_Start();
+  DxcThreadMalloc TM(pMalloc);
+  hr = ThreadMallocDxcCreateInstance(rclsid, riid, ppv);
   DxcEtw_DXCompilerCreateInstance_Stop(hr);
   return hr;
 }

이 변경점에서 너무 많은 파일들이 변경되어 몇몇 파일들은 표시되지 않았습니다.