Bladeren bron

Added pass to remove regions with no escaping values or side effects. (#2508)

* Erase dead region

* Pass dependencies

* Simpler heuristic, only checking that Begin dominates End and End post dominates Begin

* Small cleanups. No longer iterating whole block to find PHIs

* A few optimizations. Fixed infinite loops caused by self-loops
Adam Yang 5 jaren geleden
bovenliggende
commit
97ec60accd

+ 1 - 0
include/dxc/HLSL/DxilGenerationPass.h

@@ -85,6 +85,7 @@ void initializeHLEnsureMetadataPass(llvm::PassRegistry&);
 void initializeHLEmitMetadataPass(llvm::PassRegistry&);
 void initializeDxilFinalizeModulePass(llvm::PassRegistry&);
 void initializeDxilEmitMetadataPass(llvm::PassRegistry&);
+void initializeDxilEraseDeadRegionPass(llvm::PassRegistry&);
 void initializeDxilExpandTrigIntrinsicsPass(llvm::PassRegistry&);
 void initializeDxilDeadFunctionEliminationPass(llvm::PassRegistry&);
 void initializeHLDeadFunctionEliminationPass(llvm::PassRegistry&);

+ 4 - 0
include/llvm/Transforms/Scalar.h

@@ -139,6 +139,10 @@ void initializeDxilConditionalMem2RegPass(PassRegistry&);
 
 Pass *createDxilLoopUnrollPass(unsigned MaxIterationAttempt);
 void initializeDxilLoopUnrollPass(PassRegistry&);
+
+Pass *createDxilEraseDeadRegionPass();
+void initializeDxilEraseDeadRegionPass(PassRegistry&);
+
 //===----------------------------------------------------------------------===//
 //
 // LowerStaticGlobalIntoAlloca. Replace static globals with alloca if only used

+ 1 - 0
lib/HLSL/DxcOptimizer.cpp

@@ -93,6 +93,7 @@ HRESULT SetupRegistryPassForHLSL() {
     initializeDxilDeadFunctionEliminationPass(Registry);
     initializeDxilEliminateOutputDynamicIndexingPass(Registry);
     initializeDxilEmitMetadataPass(Registry);
+    initializeDxilEraseDeadRegionPass(Registry);
     initializeDxilExpandTrigIntrinsicsPass(Registry);
     initializeDxilFinalizeModulePass(Registry);
     initializeDxilFixConstArrayInitializerPass(Registry);

+ 3 - 0
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -614,6 +614,9 @@ void PassManagerBuilder::populateModulePassManager(
 
   // HLSL Change Begins.
   if (!HLSLHighLevel) {
+    if (OptLevel > 0)
+      MPM.add(createDxilEraseDeadRegionPass());
+
     MPM.add(createDxilConvergentClearPass());
     MPM.add(createDeadCodeEliminationPass()); // DCE needed after clearing convergence
                                               // annotations before CreateHandleForLib

+ 1 - 0
lib/Transforms/Scalar/CMakeLists.txt

@@ -46,6 +46,7 @@ add_llvm_library(LLVMScalarOpts
   ScalarReplAggregates.cpp
   ScalarReplAggregatesHLSL.cpp  # HLSL Change
   DxilLoopUnroll.cpp # HLSL Change
+  DxilEraseDeadRegion.cpp # HLSL Change
   DxilFixConstArrayInitializer.cpp # HLSL Change
   Scalarizer.cpp
   SeparateConstOffsetFromGEP.cpp

+ 192 - 0
lib/Transforms/Scalar/DxilEraseDeadRegion.cpp

@@ -0,0 +1,192 @@
+//===- DxilEraseDeadRegion.cpp - Heuristically Remove Dead Region ---------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+// Overview:
+//   1. Identify potentially dead regions by finding blocks with multiple
+//      predecessors but no PHIs
+//   2. Find common dominant ancestor of all the predecessors
+//   3. Ensure original block post-dominates the ancestor
+//   4. Ensure no instructions in the region have side effects (not including
+//      original block and ancestor)
+//   5. Remove all blocks in the region (excluding original block and ancestor)
+//
+
+#include "llvm/Pass.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/BasicBlock.h"
+
+#include <unordered_map>
+#include <unordered_set>
+
+using namespace llvm;
+
+struct DxilEraseDeadRegion : public FunctionPass {
+  static char ID;
+
+  DxilEraseDeadRegion() : FunctionPass(ID) {
+    initializeDxilEraseDeadRegionPass(*PassRegistry::getPassRegistry());
+  }
+
+  std::unordered_map<BasicBlock *, bool> m_HasSideEffect;
+
+  bool HasSideEffects(BasicBlock *BB) {
+    auto FindIt = m_HasSideEffect.find(BB);
+    if (FindIt != m_HasSideEffect.end()) {
+      return FindIt->second;
+    }
+
+    for (Instruction &I : *BB)
+      if (I.mayHaveSideEffects()) {
+        m_HasSideEffect[BB] = true;
+        return true;
+      }
+
+    m_HasSideEffect[BB] = false;
+    return false;
+  }
+
+  bool FindDeadRegion(PostDominatorTree *PDT, BasicBlock *Begin, BasicBlock *End, std::set<BasicBlock *> &Region) {
+    std::vector<BasicBlock *> WorkList;
+    auto ProcessSuccessors = [this, &WorkList, Begin, End, &Region, PDT](BasicBlock *BB) {
+      for (BasicBlock *Succ : successors(BB)) {
+        if (Succ == End) continue;
+        if (Succ == Begin) return false; // If goes back to the beginning, there's a loop, give up.
+        if (Region.count(Succ)) continue;
+        if (this->HasSideEffects(Succ)) return false; // Give up if the block may have side effects
+
+        WorkList.push_back(Succ);
+        Region.insert(Succ);
+      }
+      return true;
+    };
+
+    if (!ProcessSuccessors(Begin))
+      return false;
+
+    while (WorkList.size()) {
+      BasicBlock *BB = WorkList.back();
+      WorkList.pop_back();
+      if (!ProcessSuccessors(BB))
+        return false;
+    }
+
+    return true;
+  }
+
+  bool TrySimplify(DominatorTree *DT, PostDominatorTree *PDT, BasicBlock *BB) {
+    // Give up if BB has any Phis
+    if (BB->begin() != BB->end() && isa<PHINode>(BB->begin()))
+      return false;
+
+    std::vector<BasicBlock *> Predecessors(pred_begin(BB), pred_end(BB));
+    if (Predecessors.size() < 2) return false;
+
+    // Give up if BB is a self loop
+    for (BasicBlock *PredBB : Predecessors)
+      if (PredBB == BB)
+        return false;
+
+    // Find the common ancestor of all the predecessors
+    BasicBlock *Common = DT->findNearestCommonDominator(Predecessors[0], Predecessors[1]);
+    if (!Common) return false;
+    for (unsigned i = 2; i < Predecessors.size(); i++) {
+      Common = DT->findNearestCommonDominator(Common, Predecessors[i]);
+      if (!Common) return false;
+    }
+
+   // If there are any metadata on Common block's branch, give up.
+    if (Common->getTerminator()->hasMetadataOtherThanDebugLoc())
+      return false;
+
+    if (!DT->properlyDominates(Common, BB))
+      return false;
+    if (!PDT->properlyDominates(BB, Common))
+      return false;
+
+    std::set<BasicBlock *> Region;
+    if (!this->FindDeadRegion(PDT, Common, BB, Region))
+      return false;
+
+    // If BB branches INTO the region, forming a loop give up.
+    for (BasicBlock *Succ : successors(BB))
+      if (Region.count(Succ))
+        return false;
+
+    // Replace Common's branch with an unconditional branch to BB
+    Common->getTerminator()->eraseFromParent();
+    BranchInst::Create(BB, Common);
+
+    // Delete the region
+    for (BasicBlock *BB : Region) {
+      for (Instruction &I : *BB)
+        I.dropAllReferences();
+      BB->dropAllReferences();
+    }
+    for (BasicBlock *BB : Region) {
+      while (BB->begin() != BB->end())
+        BB->begin()->eraseFromParent();
+      BB->eraseFromParent();
+    }
+
+    return true;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<PostDominatorTree>();
+  }
+
+  bool runOnFunction(Function &F) override {
+    auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    auto *PDT = &getAnalysis<PostDominatorTree>();
+
+    std::unordered_set<BasicBlock *> FailedSet;
+    bool Changed = false;
+    while (1) {
+      bool LocalChanged = false;
+      for (Function::iterator It = F.begin(), E = F.end(); It != E; It++) {
+        BasicBlock &BB = *It;
+        if (FailedSet.count(&BB))
+          continue;
+
+        if (this->TrySimplify(DT, PDT, &BB)) {
+          LocalChanged = true;
+          break;
+        }
+        else {
+          FailedSet.insert(&BB);
+        }
+      }
+
+      Changed |= LocalChanged;
+      if (!LocalChanged)
+        break;
+    }
+
+    return Changed;
+  }
+};
+
+char DxilEraseDeadRegion::ID;
+
+Pass *llvm::createDxilEraseDeadRegionPass() {
+  return new DxilEraseDeadRegion();
+}
+
+INITIALIZE_PASS_BEGIN(DxilEraseDeadRegion, "dxil-erase-dead-region", "Dxil Erase Dead Region", false, false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTree)
+INITIALIZE_PASS_END(DxilEraseDeadRegion, "dxil-erase-dead-region", "Dxil Erase Dead Region", false, false)
+
+

+ 1 - 0
utils/hct/hctdb.py

@@ -2003,6 +2003,7 @@ class db_dxil(object):
         add_pass('indvars', 'IndVarSimplify', "Induction Variable Simplification", [])
         add_pass('loop-idiom', 'LoopIdiomRecognize', "Recognize loop idioms", [])
         add_pass('dxil-loop-unroll', 'DxilLoopUnroll', 'DxilLoopUnroll', [])
+        add_pass('dxil-erase-dead-region', 'DxilEraseDeadRegion', 'DxilEraseDeadRegion', [])
         add_pass('loop-deletion', 'LoopDeletion', "Delete dead loops", [])
         add_pass('loop-interchange', 'LoopInterchange', 'Interchanges loops for cache reuse', [])
         add_pass('loop-unroll', 'LoopUnroll', 'Unroll loops', [