Sfoglia il codice sorgente

Fix remaining HLModule/DxilModule dependencies from llvm::Module, etc. (#1656)

- Use one callback for global removal - Function and GlobalValue use this
- Share callback for HLModule and DxilModule, transfer is handled by
  setting callback on construction and checking pointer before clearing
Tex Riddell 6 anni fa
parent
commit
532368e9a1
5 ha cambiato i file con 42 aggiunte e 37 eliminazioni
  1. 9 8
      include/llvm/IR/Module.h
  2. 13 12
      lib/DXIL/DxilModule.cpp
  3. 16 12
      lib/HLSL/HLModule.cpp
  4. 2 2
      lib/IR/Function.cpp
  5. 2 3
      lib/IR/Globals.cpp

+ 9 - 8
include/llvm/IR/Module.h

@@ -690,23 +690,24 @@ public:
 /// @}
 
   // HLSL Change start
-  typedef void (*RemoveFunctionCallback)(llvm::Module*, llvm::Function*);
-  RemoveFunctionCallback   pHLModuleRemoveFunction = nullptr;
-  RemoveFunctionCallback pDxilModuleRemoveFunction = nullptr;
-  void RemoveFunctionHook(llvm::Function* F) {
-    if   (pHLModuleRemoveFunction)   (*pHLModuleRemoveFunction)(this, F);
-    if (pDxilModuleRemoveFunction) (*pDxilModuleRemoveFunction)(this, F);
+  typedef void (*RemoveGlobalCallback)(llvm::Module*, llvm::GlobalObject*);
+  typedef void(*ResetModuleCallback)(llvm::Module*);
+  RemoveGlobalCallback pfnRemoveGlobal = nullptr;
+  void CallRemoveGlobalHook(llvm::GlobalObject* G) {
+    if (pfnRemoveGlobal) (*pfnRemoveGlobal)(this, G);
   }
   bool HasHLModule() const { return TheHLModule != nullptr; }
   void SetHLModule(hlsl::HLModule *pValue) { TheHLModule = pValue; }
   hlsl::HLModule &GetHLModule() { return *TheHLModule; }
   hlsl::HLModule &GetOrCreateHLModule(bool skipInit = false);
-  void ResetHLModule();
+  ResetModuleCallback pfnResetHLModule = nullptr;
+  void ResetHLModule() { if (pfnResetHLModule) (*pfnResetHLModule)(this); }
   bool HasDxilModule() const { return TheDxilModule != nullptr; }
   void SetDxilModule(hlsl::DxilModule *pValue) { TheDxilModule = pValue; }
   hlsl::DxilModule &GetDxilModule() const { return *TheDxilModule; }
   hlsl::DxilModule &GetOrCreateDxilModule(bool skipInit = false);
-  void ResetDxilModule();
+  ResetModuleCallback pfnResetDxilModule = nullptr;
+  void ResetDxilModule() { if (pfnResetDxilModule) (*pfnResetDxilModule)(this); }
   // HLSL Change end
 };
 

+ 13 - 12
lib/DXIL/DxilModule.cpp

@@ -76,9 +76,16 @@ const char* kFP32DenormValueFtzString      = "ftz";
 }
 
 // Avoid dependency on DxilModule from llvm::Module using this:
-void DxilModule_RemoveFunction(llvm::Module* M, llvm::Function* F) {
-  if (M && F && M->HasDxilModule())
-    M->GetDxilModule().RemoveFunction(F);
+void DxilModule_RemoveGlobal(llvm::Module* M, llvm::GlobalObject* G) {
+  if (M && G && M->HasDxilModule()) {
+    if (llvm::Function *F = dyn_cast<llvm::Function>(G))
+      M->GetDxilModule().RemoveFunction(F);
+  }
+}
+void DxilModule_ResetModule(llvm::Module* M) {
+  if (M && M->HasDxilModule())
+    delete &M->GetDxilModule();
+  M->SetDxilModule(nullptr);
 }
 
 //------------------------------------------------------------------------------
@@ -109,7 +116,7 @@ DxilModule::DxilModule(Module *pModule)
 {
 
   DXASSERT_NOMSG(m_pModule != nullptr);
-  m_pModule->pDxilModuleRemoveFunction = &DxilModule_RemoveFunction;
+  m_pModule->pfnRemoveGlobal = &DxilModule_RemoveGlobal;
 
 #if defined(_DEBUG) || defined(DBG)
   // Pin LLVM dump methods.
@@ -122,7 +129,8 @@ DxilModule::DxilModule(Module *pModule)
 }
 
 DxilModule::~DxilModule() {
-  m_pModule->pDxilModuleRemoveFunction = nullptr;
+  if (m_pModule->pfnRemoveGlobal == &DxilModule_RemoveGlobal)
+    m_pModule->pfnRemoveGlobal = nullptr;
 }
 
 LLVMContext &DxilModule::GetCtx() const { return m_Ctx; }
@@ -1579,11 +1587,4 @@ hlsl::DxilModule &Module::GetOrCreateDxilModule(bool skipInit) {
   return GetDxilModule();
 }
 
-void Module::ResetDxilModule() {
-  if (HasDxilModule()) {
-    delete TheDxilModule;
-    TheDxilModule = nullptr;
-  }
-}
-
 }

+ 16 - 12
lib/HLSL/HLModule.cpp

@@ -36,9 +36,18 @@ using std::unique_ptr;
 namespace hlsl {
 
 // Avoid dependency on HLModule from llvm::Module using this:
-void HLModule_RemoveFunction(llvm::Module* M, llvm::Function* F) {
-  if (M && F && M->HasHLModule())
-    M->GetHLModule().RemoveFunction(F);
+void HLModule_RemoveGlobal(llvm::Module* M, llvm::GlobalObject* G) {
+  if (M && G && M->HasHLModule()) {
+    if (llvm::GlobalVariable *GV = dyn_cast<llvm::GlobalVariable>(G))
+      M->GetHLModule().RemoveGlobal(GV);
+    else if (llvm::Function *F = dyn_cast<llvm::Function>(G))
+      M->GetHLModule().RemoveFunction(F);
+  }
+}
+void HLModule_ResetModule(llvm::Module* M) {
+  if (M && M->HasHLModule())
+    delete &M->GetHLModule();
+  M->SetHLModule(nullptr);
 }
 
 //------------------------------------------------------------------------------
@@ -64,7 +73,8 @@ HLModule::HLModule(Module *pModule)
     , m_DefaultLinkage(DXIL::DefaultLinkage::Default)
     , m_pTypeSystem(llvm::make_unique<DxilTypeSystem>(pModule)) {
   DXASSERT_NOMSG(m_pModule != nullptr);
-  m_pModule->pHLModuleRemoveFunction = &HLModule_RemoveFunction;
+  m_pModule->pfnRemoveGlobal = &HLModule_RemoveGlobal;
+  m_pModule->pfnResetHLModule = &HLModule_ResetModule;
 
   // Pin LLVM dump methods. TODO: make debug-only.
   void (__thiscall Module::*pfnModuleDump)() const = &Module::dump;
@@ -73,7 +83,8 @@ HLModule::HLModule(Module *pModule)
 }
 
 HLModule::~HLModule() {
-  m_pModule->pHLModuleRemoveFunction = nullptr;
+  if (m_pModule->pfnRemoveGlobal == &HLModule_RemoveGlobal)
+    m_pModule->pfnRemoveGlobal = nullptr;
 }
 
 LLVMContext &HLModule::GetCtx() const { return m_Ctx; }
@@ -1312,11 +1323,4 @@ hlsl::HLModule &Module::GetOrCreateHLModule(bool skipInit) {
   return GetHLModule();
 }
 
-void Module::ResetHLModule() {
-  if (HasHLModule()) {
-    delete TheHLModule;
-    TheHLModule = nullptr;
-  }
-}
-
 }

+ 2 - 2
lib/IR/Function.cpp

@@ -235,12 +235,12 @@ Type *Function::getReturnType() const {
 }
 
 void Function::removeFromParent() {
-  getParent()->RemoveFunctionHook(this); // HLSL Change
+  getParent()->CallRemoveGlobalHook(this); // HLSL Change
   getParent()->getFunctionList().remove(this);
 }
 
 void Function::eraseFromParent() {
-  getParent()->RemoveFunctionHook(this); // HLSL Change
+  getParent()->CallRemoveGlobalHook(this); // HLSL Change
   getParent()->getFunctionList().erase(this);
 }
 

+ 2 - 3
lib/IR/Globals.cpp

@@ -19,7 +19,6 @@
 #include "llvm/IR/GlobalAlias.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/Module.h"
-#include "dxc/HLSL/HLModule.h" // HLSL Change
 #include "llvm/IR/Operator.h"
 #include "llvm/Support/ErrorHandling.h"
 using namespace llvm;
@@ -189,12 +188,12 @@ void GlobalVariable::setParent(Module *parent) {
 }
 
 void GlobalVariable::removeFromParent() {
-  if (getParent()->HasHLModule()) getParent()->GetHLModule().RemoveGlobal(this);
+  getParent()->CallRemoveGlobalHook(this);  // HLSL Change
   getParent()->getGlobalList().remove(this);
 }
 
 void GlobalVariable::eraseFromParent() {
-  if (getParent()->HasHLModule()) getParent()->GetHLModule().RemoveGlobal(this);
+  getParent()->CallRemoveGlobalHook(this);  // HLSL Change
   getParent()->getGlobalList().erase(this);
 }