Просмотр исходного кода

Simplifying branches to reduce complexity of cfg. (#2931)

Adam Yang 5 лет назад
Родитель
Сommit
598f1d14bf

+ 1 - 0
include/dxc/DXIL/DxilConstants.h

@@ -1440,6 +1440,7 @@ namespace DXIL {
 
   static const char *kDxBreakFuncName = "dx.break";
   static const char *kDxBreakCondName = "dx.break.cond";
+  static const char *kDxBreakMDName = "dx.break.br";
 
 } // namespace DXIL
 

+ 2 - 0
include/llvm/Analysis/DxilValueCache.h

@@ -18,6 +18,7 @@ namespace llvm {
 class Module;
 class DominatorTree;
 class Constant;
+class ConstantInt;
 
 struct DxilValueCache : public ImmutablePass {
   static char ID;
@@ -73,6 +74,7 @@ public:
   void dump() const;
   Value *GetValue(Value *V, DominatorTree *DT=nullptr);
   Constant *GetConstValue(Value *V, DominatorTree *DT = nullptr);
+  ConstantInt *GetConstInt(Value *V, DominatorTree *DT = nullptr);
   void ResetUnknowns() { ValueMap.ResetUnknowns(); }
   bool IsAlwaysReachable(BasicBlock *BB, DominatorTree *DT=nullptr);
   bool IsUnreachable(BasicBlock *BB, DominatorTree *DT=nullptr);

+ 8 - 0
lib/Analysis/DxilValueCache.cpp

@@ -385,6 +385,8 @@ const char *DxilValueCache::getPassName() const {
 }
 
 Value *DxilValueCache::GetValue(Value *V, DominatorTree *DT) {
+  if (dyn_cast<Constant>(V))
+    return V;
   if (Value *NewV = ValueMap.Get(V))
     return NewV;
   return ProcessValue(V, DT);
@@ -396,6 +398,12 @@ Constant *DxilValueCache::GetConstValue(Value *V, DominatorTree *DT) {
   return nullptr;
 }
 
+ConstantInt *DxilValueCache::GetConstInt(Value *V, DominatorTree *DT) {
+  if (Value *NewV = GetValue(V))
+    return dyn_cast<ConstantInt>(NewV);
+  return nullptr;
+}
+
 bool DxilValueCache::IsAlwaysReachable(BasicBlock *BB, DominatorTree *DT) {
   ProcessValue(BB, DT);
   return IsAlwaysReachable_(BB);

+ 8 - 0
lib/HLSL/DxilPreparePasses.cpp

@@ -662,6 +662,14 @@ private:
         }
       }
       BreakFunc->eraseFromParent();
+
+      for (Function &F : M) {
+        for (BasicBlock &BB : F) {
+          if (BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator())) {
+            BI->setMetadata(DXIL::kDxBreakMDName, nullptr);
+          }
+        }
+      }
     }
   }
 };

+ 36 - 19
lib/Transforms/Scalar/DxilRemoveDeadBlocks.cpp

@@ -24,10 +24,23 @@
 
 using namespace llvm;
 
+static void RemoveIncomingValueFrom(BasicBlock *SuccBB, BasicBlock *BB) {
+  for (auto inst_it = SuccBB->begin(); inst_it != SuccBB->end();) {
+    Instruction *I = &*(inst_it++);
+    if (PHINode *PN = dyn_cast<PHINode>(I))
+      PN->removeIncomingValue(BB, true);
+    else
+      break;
+  }
+}
+
+
 static bool EraseDeadBlocks(Function &F, DxilValueCache *DVC) {
   std::unordered_set<BasicBlock *> Seen;
   std::vector<BasicBlock *> WorkList;
 
+  bool Changed = false;
+
   auto Add = [&WorkList, &Seen](BasicBlock *BB) {
     if (!Seen.count(BB)) {
       WorkList.push_back(BB);
@@ -48,17 +61,27 @@ static bool EraseDeadBlocks(Function &F, DxilValueCache *DVC) {
         Add(Succ);
       }
       else {
-        bool IsConstant = false;
-        if (Value *V = DVC->GetValue(Br->getCondition())) {
-          if (ConstantInt *C = dyn_cast<ConstantInt>(V)) {
-            bool IsTrue = C->getLimitedValue() != 0;
-            BasicBlock *Succ = IsTrue ?
-              Br->getSuccessor(0) : Br->getSuccessor(1);
-            Add(Succ);
-            IsConstant = true;
+        if (ConstantInt *C = DVC->GetConstInt(Br->getCondition())) {
+          bool IsTrue = C->getLimitedValue() != 0;
+          BasicBlock *Succ = Br->getSuccessor(IsTrue ? 0 : 1);
+          BasicBlock *NotSucc = Br->getSuccessor(!IsTrue ? 0 : 1);
+
+          Add(Succ);
+
+          // Rewrite conditional branch as unconditional branch if
+          // we don't have structural information that needs it to
+          // be alive.
+          if (!Br->getMetadata(hlsl::DXIL::kDxBreakMDName)) {
+            BranchInst *NewBr = BranchInst::Create(Succ, BB);
+            hlsl::DxilMDHelper::CopyMetadata(*NewBr, *Br);
+            RemoveIncomingValueFrom(NotSucc, BB);
+
+            Br->eraseFromParent();
+            Br = nullptr;
+            Changed = true;
           }
         }
-        if (!IsConstant) {
+        else {
           Add(Br->getSuccessor(0));
           Add(Br->getSuccessor(1));
         }
@@ -72,7 +95,7 @@ static bool EraseDeadBlocks(Function &F, DxilValueCache *DVC) {
   }
 
   if (Seen.size() == F.size())
-    return false;
+    return Changed;
 
   std::vector<BasicBlock *> DeadBlocks;
 
@@ -87,7 +110,7 @@ static bool EraseDeadBlocks(Function &F, DxilValueCache *DVC) {
     // Make predecessors branch somewhere else and fix the phi nodes
     for (auto pred_it = pred_begin(BB); pred_it != pred_end(BB);) {
       BasicBlock *PredBB = *(pred_it++);
-      if (!Seen.count(PredBB))
+      if (!Seen.count(PredBB)) // Don't bother fixing it if it's gonna get deleted anyway
         continue;
       TerminatorInst *TI = PredBB->getTerminator();
       if (!TI) continue;
@@ -105,14 +128,8 @@ static bool EraseDeadBlocks(Function &F, DxilValueCache *DVC) {
     // Fix phi nodes in successors
     for (auto succ_it = succ_begin(BB); succ_it != succ_end(BB); succ_it++) {
       BasicBlock *SuccBB = *succ_it;
-      if (!Seen.count(SuccBB)) continue;
-      for (auto inst_it = SuccBB->begin(); inst_it != SuccBB->end();) {
-        Instruction *I = &*(inst_it++);
-        if (PHINode *PN = dyn_cast<PHINode>(I))
-          PN->removeIncomingValue(BB, true);
-        else
-          break;
-      }
+      if (!Seen.count(SuccBB)) continue; // Don't bother fixing it if it's gonna get deleted anyway
+      RemoveIncomingValueFrom(SuccBB, BB);
     }
 
     // Erase all instructions in block

+ 3 - 0
tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp

@@ -2653,6 +2653,9 @@ void AddDxBreak(Module &M, const SmallVector<llvm::BranchInst*, 16> &DxBreaks) {
     if (WaveUsers.count(BI->getParent()->getParent())) {
       CallInst *Call = CallInst::Create(FT, func, ArrayRef<Value *>(), "", BI);
       BI->setCondition(Call);
+      if (!BI->getMetadata(DXIL::kDxBreakMDName)) {
+        BI->setMetadata(DXIL::kDxBreakMDName, llvm::MDNode::get(BI->getContext(), {}));
+      }
     }
   }
 }

+ 1 - 1
tools/clang/test/HLSLFileCheck/dxil/debug/new_noop_no_fold_double.hlsl

@@ -34,7 +34,7 @@ float4 main() : SV_Target {
   // CHECK: load i32, i32*
   // CHECK-SAME: @dx.nothing
 
-  // CHECK: br i1
+  // CHECK: br
   if (w >= 0) {
     tex = tex1;
     // CHECK: load i32, i32*

+ 1 - 1
tools/clang/test/HLSLFileCheck/dxil/debug/new_noop_no_fold_int.hlsl

@@ -36,7 +36,7 @@ float4 main() : SV_Target {
   // CHECK: load i32, i32*
   // CHECK-SAME: @dx.nothing
 
-  // CHECK: br i1
+  // CHECK: br
   if (w >= 0) {
     tex = tex1;
     // CHECK: load i32, i32*

+ 1 - 1
tools/clang/test/HLSLFileCheck/dxil/debug/new_noops_no_fold.hlsl

@@ -34,7 +34,7 @@ float4 main() : SV_Target {
   // CHECK: load i32, i32*
   // CHECK-SAME: @dx.nothing
 
-  // CHECK: br i1
+  // CHECK: br
   if (w >= 0) {
     tex = tex1;
     // CHECK: load i32, i32*

+ 1 - 1
tools/clang/test/HLSLFileCheck/dxil/debug/new_noops_no_fold_vec.hlsl

@@ -41,7 +41,7 @@ float4 main() : SV_Target {
   // CHECK: load i32, i32*
   // CHECK-SAME: @dx.nothing
 
-  // CHECK: br i1
+  // CHECK: br
   if (foo.x+bar.y >= 0) {
     tex = tex1;
     // CHECK: load i32, i32*

+ 1 - 1
tools/clang/test/HLSLFileCheck/dxil/debug/no_fold.hlsl

@@ -18,7 +18,7 @@ float4 main() : SV_Target {
   // CHECK: fdiv
 
   Texture2D tex = tex0; 
-  // CHECK: br i1
+  // CHECK: br
   if (w >= 0) {
     tex = tex1;
     // CHECK: br

+ 1 - 1
tools/clang/test/HLSLFileCheck/dxil/debug/no_fold_vec.hlsl

@@ -23,7 +23,7 @@ float4 main() : SV_Target {
   // CHECK: fdiv
 
   Texture2D tex = tex0; 
-  // CHECK: br i1
+  // CHECK: br
   if (foo.x+bar.y >= 0) {
     tex = tex1;
     // CHECK: br

+ 1 - 1
tools/clang/test/HLSLFileCheck/dxil/debug/noop_no_fold_double.hlsl

@@ -30,7 +30,7 @@ float4 main() : SV_Target {
   // CHECK: load i32, i32*
   // CHECK-SAME: @dx.nothing
 
-  // CHECK: br i1
+  // CHECK: br
   if (w >= 0) {
     tex = tex1;
     // CHECK: load i32, i32*

+ 1 - 1
tools/clang/test/HLSLFileCheck/dxil/debug/noop_no_fold_int.hlsl

@@ -32,7 +32,7 @@ float4 main() : SV_Target {
   // CHECK: load i32, i32*
   // CHECK-SAME: @dx.nothing
 
-  // CHECK: br i1
+  // CHECK: br
   if (w >= 0) {
     tex = tex1;
     // CHECK: load i32, i32*

+ 1 - 1
tools/clang/test/HLSLFileCheck/dxil/debug/noops_no_fold.hlsl

@@ -30,7 +30,7 @@ float4 main() : SV_Target {
   // CHECK: load i32, i32*
   // CHECK-SAME: @dx.nothing
 
-  // CHECK: br i1
+  // CHECK: br
   if (w >= 0) {
     tex = tex1;
     // CHECK: load i32, i32*

+ 1 - 1
tools/clang/test/HLSLFileCheck/dxil/debug/noops_no_fold_vec.hlsl

@@ -37,7 +37,7 @@ float4 main() : SV_Target {
   // CHECK: load i32, i32*
   // CHECK-SAME: @dx.nothing
 
-  // CHECK: br i1
+  // CHECK: br
   if (foo.x+bar.y >= 0) {
     tex = tex1;
     // CHECK: load i32, i32*

+ 1 - 1
tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/reduction/WaveBreakAndUnrollCS.hlsl

@@ -1,4 +1,4 @@
-// RUN: %dxc -T cs_6_0 %s | FileCheck %s
+// RUN: %dxc -T cs_6_0 -Od %s | FileCheck %s
 // A test of explicit loop unrolling on a loop that uses a wave op in a break block
 
 // CHECK: void @main