DxilEraseDeadRegion.cpp 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. //===- DxilEraseDeadRegion.cpp - Heuristically Remove Dead Region ---------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. // Overview:
  10. // 1. Identify potentially dead regions by finding blocks with multiple
  11. // predecessors but no PHIs
  12. // 2. Find common dominant ancestor of all the predecessors
  13. // 3. Ensure original block post-dominates the ancestor
  14. // 4. Ensure no instructions in the region have side effects (not including
  15. // original block and ancestor)
  16. // 5. Remove all blocks in the region (excluding original block and ancestor)
  17. //
  18. #include "llvm/Pass.h"
  19. #include "llvm/Analysis/CFG.h"
  20. #include "llvm/Analysis/PostDominators.h"
  21. #include "llvm/Transforms/Scalar.h"
  22. #include "llvm/IR/Instructions.h"
  23. #include "llvm/IR/IntrinsicInst.h"
  24. #include "llvm/IR/Function.h"
  25. #include "llvm/IR/BasicBlock.h"
  26. #include <unordered_map>
  27. #include <unordered_set>
  28. using namespace llvm;
  29. struct DxilEraseDeadRegion : public FunctionPass {
  30. static char ID;
  31. DxilEraseDeadRegion() : FunctionPass(ID) {
  32. initializeDxilEraseDeadRegionPass(*PassRegistry::getPassRegistry());
  33. }
  34. std::unordered_map<BasicBlock *, bool> m_HasSideEffect;
  35. bool HasSideEffects(BasicBlock *BB) {
  36. auto FindIt = m_HasSideEffect.find(BB);
  37. if (FindIt != m_HasSideEffect.end()) {
  38. return FindIt->second;
  39. }
  40. for (Instruction &I : *BB)
  41. if (I.mayHaveSideEffects()) {
  42. m_HasSideEffect[BB] = true;
  43. return true;
  44. }
  45. m_HasSideEffect[BB] = false;
  46. return false;
  47. }
  48. bool FindDeadRegion(BasicBlock *Begin, BasicBlock *End, std::set<BasicBlock *> &Region) {
  49. std::vector<BasicBlock *> WorkList;
  50. auto ProcessSuccessors = [this, &WorkList, Begin, End, &Region](BasicBlock *BB) {
  51. for (BasicBlock *Succ : successors(BB)) {
  52. if (Succ == End) continue;
  53. if (Succ == Begin) return false; // If goes back to the beginning, there's a loop, give up.
  54. if (Region.count(Succ)) continue;
  55. if (this->HasSideEffects(Succ)) return false; // Give up if the block may have side effects
  56. WorkList.push_back(Succ);
  57. Region.insert(Succ);
  58. }
  59. return true;
  60. };
  61. if (!ProcessSuccessors(Begin))
  62. return false;
  63. while (WorkList.size()) {
  64. BasicBlock *BB = WorkList.back();
  65. WorkList.pop_back();
  66. if (!ProcessSuccessors(BB))
  67. return false;
  68. }
  69. return Region.size() != 0;
  70. }
  71. bool TrySimplify(DominatorTree *DT, PostDominatorTree *PDT, BasicBlock *BB) {
  72. // Give up if BB has any Phis
  73. if (BB->begin() != BB->end() && isa<PHINode>(BB->begin()))
  74. return false;
  75. std::vector<BasicBlock *> Predecessors(pred_begin(BB), pred_end(BB));
  76. if (Predecessors.size() < 2) return false;
  77. // Give up if BB is a self loop
  78. for (BasicBlock *PredBB : Predecessors)
  79. if (PredBB == BB)
  80. return false;
  81. // Find the common ancestor of all the predecessors
  82. BasicBlock *Common = DT->findNearestCommonDominator(Predecessors[0], Predecessors[1]);
  83. if (!Common) return false;
  84. for (unsigned i = 2; i < Predecessors.size(); i++) {
  85. Common = DT->findNearestCommonDominator(Common, Predecessors[i]);
  86. if (!Common) return false;
  87. }
  88. // If there are any metadata on Common block's branch, give up.
  89. if (Common->getTerminator()->hasMetadataOtherThanDebugLoc())
  90. return false;
  91. if (!DT->properlyDominates(Common, BB))
  92. return false;
  93. if (!PDT->properlyDominates(BB, Common))
  94. return false;
  95. std::set<BasicBlock *> Region;
  96. if (!this->FindDeadRegion(Common, BB, Region))
  97. return false;
  98. // If BB branches INTO the region, forming a loop give up.
  99. for (BasicBlock *Succ : successors(BB))
  100. if (Region.count(Succ))
  101. return false;
  102. // Replace Common's branch with an unconditional branch to BB
  103. Common->getTerminator()->eraseFromParent();
  104. BranchInst::Create(BB, Common);
  105. // Delete the region
  106. for (BasicBlock *BB : Region) {
  107. for (Instruction &I : *BB)
  108. I.dropAllReferences();
  109. BB->dropAllReferences();
  110. }
  111. for (BasicBlock *BB : Region) {
  112. while (BB->begin() != BB->end())
  113. BB->begin()->eraseFromParent();
  114. BB->eraseFromParent();
  115. }
  116. return true;
  117. }
  118. void getAnalysisUsage(AnalysisUsage &AU) const override {
  119. AU.addRequired<DominatorTreeWrapperPass>();
  120. AU.addRequired<PostDominatorTree>();
  121. }
  122. bool runOnFunction(Function &F) override {
  123. auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  124. auto *PDT = &getAnalysis<PostDominatorTree>();
  125. std::unordered_set<BasicBlock *> FailedSet;
  126. bool Changed = false;
  127. while (1) {
  128. bool LocalChanged = false;
  129. for (Function::iterator It = F.begin(), E = F.end(); It != E; It++) {
  130. BasicBlock &BB = *It;
  131. if (FailedSet.count(&BB))
  132. continue;
  133. if (this->TrySimplify(DT, PDT, &BB)) {
  134. LocalChanged = true;
  135. break;
  136. }
  137. else {
  138. FailedSet.insert(&BB);
  139. }
  140. }
  141. Changed |= LocalChanged;
  142. if (!LocalChanged)
  143. break;
  144. }
  145. return Changed;
  146. }
  147. };
  148. char DxilEraseDeadRegion::ID;
  149. Pass *llvm::createDxilEraseDeadRegionPass() {
  150. return new DxilEraseDeadRegion();
  151. }
  152. INITIALIZE_PASS_BEGIN(DxilEraseDeadRegion, "dxil-erase-dead-region", "Dxil Erase Dead Region", false, false)
  153. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  154. INITIALIZE_PASS_DEPENDENCY(PostDominatorTree)
  155. INITIALIZE_PASS_END(DxilEraseDeadRegion, "dxil-erase-dead-region", "Dxil Erase Dead Region", false, false)