浏览代码

Made nested unroll work. (#3436)

Adam Yang 4 年之前
父节点
当前提交
f5f4b1aa35

+ 131 - 58
lib/Transforms/Scalar/DxilLoopUnroll.cpp

@@ -91,32 +91,22 @@
 using namespace llvm;
 using namespace hlsl;
 
-// Copied over from LoopUnroll.cpp - RemapInstruction()
-static inline void RemapInstruction(Instruction *I,
-                                    ValueToValueMapTy &VMap) {
-  for (unsigned op = 0, E = I->getNumOperands(); op != E; ++op) {
-    Value *Op = I->getOperand(op);
-    ValueToValueMapTy::iterator It = VMap.find(Op);
-    if (It != VMap.end())
-      I->setOperand(op, It->second);
-  }
-
-  if (PHINode *PN = dyn_cast<PHINode>(I)) {
-    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
-      ValueToValueMapTy::iterator It = VMap.find(PN->getIncomingBlock(i));
-      if (It != VMap.end())
-        PN->setIncomingBlock(i, cast<BasicBlock>(It->second));
-    }
-  }
-}
-
-
 namespace {
 
+struct ClonedIteration {
+  SmallVector<BasicBlock *, 16> Body;
+  BasicBlock *Latch = nullptr;
+  BasicBlock *Header = nullptr;
+  ValueToValueMapTy VarMap;
+  SetVector<BasicBlock *> Extended; // Blocks that are included in the clone that are not in the core loop body.
+  ClonedIteration() {}
+};
+
 class DxilLoopUnroll : public LoopPass {
 public:
   static char ID;
 
+  std::set<Loop *> LoopsThatFailed;
   std::unordered_set<Function *> CleanedUpAlloca;
   unsigned MaxIterationAttempt = 0;
   bool OnlyWarnOnFail = false;
@@ -132,6 +122,7 @@ public:
   }
   const char *getPassName() const override { return "Dxil Loop Unroll"; }
   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+  bool doFinalization() override;
   bool IsLoopSafeToClone(Loop *L);
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<LoopInfoWrapperPass>();
@@ -154,9 +145,29 @@ public:
     OS << ",MaxIterationAttempt=" << MaxIterationAttempt;
     OS << ",OnlyWarnOnFail=" << OnlyWarnOnFail;
   }
-
+  void RecursivelyRemoveLoopOnSuccess(LPPassManager &LPM, Loop *L);
+  void RecursivelyRecreateSubLoopForIteration(LPPassManager &LPM, LoopInfo *LI, Loop *OuterL, Loop *L, ClonedIteration &Iter, unsigned Depth=0);
 };
 
+// Copied over from LoopUnroll.cpp - RemapInstruction()
+static inline void DxilLoopUnrollRemapInstruction(Instruction *I,
+                                    ValueToValueMapTy &VMap) {
+  for (unsigned op = 0, E = I->getNumOperands(); op != E; ++op) {
+    Value *Op = I->getOperand(op);
+    ValueToValueMapTy::iterator It = VMap.find(Op);
+    if (It != VMap.end())
+      I->setOperand(op, It->second);
+  }
+
+  if (PHINode *PN = dyn_cast<PHINode>(I)) {
+    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+      ValueToValueMapTy::iterator It = VMap.find(PN->getIncomingBlock(i));
+      if (It != VMap.end())
+        PN->setIncomingBlock(i, cast<BasicBlock>(It->second));
+    }
+  }
+}
+
 char DxilLoopUnroll::ID;
 
 static void FailLoopUnroll(bool WarnOnly, Function *F, DebugLoc DL, const Twine &Message) {
@@ -167,15 +178,6 @@ static void FailLoopUnroll(bool WarnOnly, Function *F, DebugLoc DL, const Twine
   Ctx.diagnose(DiagnosticInfoDxil(F, DL.get(), Message, severity));
 }
 
-struct LoopIteration {
-  SmallVector<BasicBlock *, 16> Body;
-  BasicBlock *Latch = nullptr;
-  BasicBlock *Header = nullptr;
-  ValueToValueMapTy VarMap;
-  SetVector<BasicBlock *> Extended; // Blocks that are included in the clone that are not in the core loop body.
-  LoopIteration() {}
-};
-
 static bool GetConstantI1(Value *V, bool *Val=nullptr) {
   if (ConstantInt *C = dyn_cast<ConstantInt>(V)) {
     if (V->getType()->isIntegerTy(1)) {
@@ -633,16 +635,22 @@ static bool BreakUpArrayAllocas(bool AllowOOBIndex, IteratorT ItBegin, IteratorT
   return Success;
 }
 
-static void RecursivelyRemoveLoopFromQueue(LPPassManager &LPM, Loop *L) {
+void DxilLoopUnroll::RecursivelyRemoveLoopOnSuccess(LPPassManager &LPM, Loop *L) {
   // Copy the sub loops into a separate list because
   // the original list may change.
   SmallVector<Loop *, 4> SubLoops(L->getSubLoops().begin(), L->getSubLoops().end());
 
   // Must remove all child loops first.
   for (Loop *SubL : SubLoops) {
-    RecursivelyRemoveLoopFromQueue(LPM, SubL);
+    RecursivelyRemoveLoopOnSuccess(LPM, SubL);
   }
 
+  // Remove any loops/subloops that failed because we are about to
+  // delete them. This will not prevent them from being retried because
+  // they would have been recreated for each cloned iteration.
+  LoopsThatFailed.erase(L);
+
+  // Loop is done and about to be deleted, remove it from queue.
   LPM.deleteLoopFromQueue(L);
 }
 
@@ -673,6 +681,49 @@ bool DxilLoopUnroll::IsLoopSafeToClone(Loop *L) {
   return true;
 }
 
+void DxilLoopUnroll::RecursivelyRecreateSubLoopForIteration(LPPassManager &LPM, LoopInfo *LI, Loop *OuterL, Loop *L, ClonedIteration &Iter, unsigned Depth) {
+  Loop *NewL = new Loop();
+
+  // Insert it into the queue in a depth-first way; otherwise
+  // `insertLoopIntoQueue` adds the parent first.
+  LPM.insertLoopIntoQueue(NewL);
+  if (OuterL) {
+    OuterL->addChildLoop(NewL);
+  }
+  else {
+    LI->addTopLevelLoop(NewL);
+  }
+
+  // First add all the blocks. It's important that we add them here
+  // (instead of letting the recursive call do the job), because
+  // the loop header must be added FIRST.
+  for (auto it = L->block_begin(), end = L->block_end(); it != end; it++) {
+    BasicBlock *OriginalBB = *it;
+    BasicBlock *NewBB = cast<BasicBlock>(Iter.VarMap[OriginalBB]);
+
+    // Manually call addBlockEntry instead of addBasicBlockToLoop because 
+    // addBasicBlockToLoop also checks and sets the BB -> Loop mapping.
+    NewL->addBlockEntry(NewBB);
+    LI->changeLoopFor(NewBB, NewL);
+
+    // Now add the block to every enclosing loop (OuterL and its ancestors).
+    // This is only necessary for the first depth of this call.
+    if (Depth == 0) {
+      Loop *OuterL_it = OuterL;
+      while (OuterL_it) {
+        OuterL_it->addBlockEntry(NewBB);
+        OuterL_it = OuterL_it->getParentLoop();
+      }
+    }
+  }
+
+  // Construct any sub-loops that exist. The BB -> Loop mapping in LI will be
+  // rewritten to the sub-loop as needed.
+  for (Loop *SubL : L->getSubLoops()) {
+    RecursivelyRecreateSubLoopForIteration(LPM, LI, NewL, SubL, Iter, Depth+1);
+  }
+}
+
 bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
 
   DebugLoc LoopLoc = L->getStartLoc(); // Debug location for the start of the loop.
@@ -843,7 +894,7 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
   // Re-establish LCSSA form to get ready for unrolling.
   CreateLCSSA(ToBeCloned, NewExits, L, *DT, LI);
 
-  SmallVector<std::unique_ptr<LoopIteration>, 16> Iterations; // List of cloned iterations
+  SmallVector<std::unique_ptr<ClonedIteration>, 16> Iterations; // List of cloned iterations
   bool Succeeded = false;
 
   unsigned MaxAttempt = this->MaxIterationAttempt;
@@ -858,11 +909,11 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
 
   for (unsigned IterationI = 0; IterationI < MaxAttempt; IterationI++) {
 
-    LoopIteration *PrevIteration = nullptr;
+    ClonedIteration *PrevIteration = nullptr;
     if (Iterations.size())
       PrevIteration = Iterations.back().get();
-    Iterations.push_back(llvm::make_unique<LoopIteration>());
-    LoopIteration &CurIteration = *Iterations.back().get();
+    Iterations.push_back(llvm::make_unique<ClonedIteration>());
+    ClonedIteration &CurIteration = *Iterations.back().get();
 
     // Clone the blocks.
     for (BasicBlock *BB : ToBeCloned) {
@@ -913,7 +964,7 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
     // Remap the instructions inside of cloned blocks.
     for (BasicBlock *BB : CurIteration.Body) {
       for (Instruction &I : *BB) {
-        ::RemapInstruction(&I, CurIteration.VarMap);
+        DxilLoopUnrollRemapInstruction(&I, CurIteration.VarMap);
       }
     }
 
@@ -995,13 +1046,20 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
   }
 
   if (Succeeded) {
+    // Now that we successfully unrolled the loop L, if there were any sub loops in L,
+    // we have to recreate all the sub-loops for each iteration of L that we cloned.
+    for (std::unique_ptr<ClonedIteration> &IterPtr : Iterations) {
+      for (Loop *SubL : L->getSubLoops())
+        RecursivelyRecreateSubLoopForIteration(LPM, LI, OuterL, SubL, *IterPtr);
+    }
+
     // We are going to be cleaning them up later. Make sure
     // they're in the entry block so deleting loop blocks doesn't
     // kill them too.
     for (AllocaInst *AI : ProblemAllocas)
       DXASSERT_LOCALVAR(AI, AI->getParent() == &F->getEntryBlock(), "Alloca is not in entry block.");
 
-    LoopIteration &FirstIteration = *Iterations.front().get();
+    ClonedIteration &FirstIteration = *Iterations.front().get();
     // Make the predecessor branch to the first new header.
     {
       BranchInst *BI = cast<BranchInst>(Predecessor->getTerminator());
@@ -1016,9 +1074,11 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
 
       // Core body blocks need to be added to outer loop
       for (size_t i = 0; i < Iterations.size(); i++) {
-        LoopIteration &Iteration = *Iterations[i].get();
+        ClonedIteration &Iteration = *Iterations[i].get();
         for (BasicBlock *BB : Iteration.Body) {
-          if (!Iteration.Extended.count(BB)) {
+          if (!Iteration.Extended.count(BB) &&
+            !OuterL->contains(BB))
+          {
             OuterL->addBasicBlockToLoop(BB, *LI);
           }
         }
@@ -1032,7 +1092,7 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
 
       // Cloned exit blocks may need to be added to outer loop
       for (size_t i = 0; i < Iterations.size(); i++) {
-        LoopIteration &Iteration = *Iterations[i].get();
+        ClonedIteration &Iteration = *Iterations[i].get();
         for (BasicBlock *BB : Iteration.Extended) {
           if (HasSuccessorsInLoop(BB, OuterL))
             OuterL->addBasicBlockToLoop(BB, *LI);
@@ -1047,7 +1107,7 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
       LI->removeBlock(BB);
 
     // Remove loop and all child loops from queue.
-    RecursivelyRemoveLoopFromQueue(LPM, L);
+    RecursivelyRemoveLoopOnSuccess(LPM, L);
 
     // Remove dead blocks.
     for (BasicBlock *BB : ToBeCloned)
@@ -1079,30 +1139,22 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
 
   // If we were unsuccessful in unrolling the loop
   else {
-    const char *Msg =
-        "Could not unroll loop. Loop bound could not be deduced at compile time. "
-        "Use [unroll(n)] to give an explicit count.";
-    if (OnlyWarnOnFail) {
-      FailLoopUnroll(true /*warn only*/, F, LoopLoc, Msg);
-    }
-    else {
-      FailLoopUnroll(false /*warn only*/, F, LoopLoc,
-        Twine(Msg) + Twine(" Use '-HV 2016' to treat this as warning."));
-    }
+    // Mark loop as failed.
+    LoopsThatFailed.insert(L);
 
     // Remove all the cloned blocks
-    for (std::unique_ptr<LoopIteration> &Ptr : Iterations) {
-      LoopIteration &Iteration = *Ptr.get();
+    for (std::unique_ptr<ClonedIteration> &Ptr : Iterations) {
+      ClonedIteration &Iteration = *Ptr.get();
       for (BasicBlock *BB : Iteration.Body)
         DetachFromSuccessors(BB);
     }
-    for (std::unique_ptr<LoopIteration> &Ptr : Iterations) {
-      LoopIteration &Iteration = *Ptr.get();
+    for (std::unique_ptr<ClonedIteration> &Ptr : Iterations) {
+      ClonedIteration &Iteration = *Ptr.get();
       for (BasicBlock *BB : Iteration.Body)
         BB->dropAllReferences();
     }
-    for (std::unique_ptr<LoopIteration> &Ptr : Iterations) {
-      LoopIteration &Iteration = *Ptr.get();
+    for (std::unique_ptr<ClonedIteration> &Ptr : Iterations) {
+      ClonedIteration &Iteration = *Ptr.get();
       for (BasicBlock *BB : Iteration.Body)
         BB->eraseFromParent();
     }
@@ -1111,6 +1163,27 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
   }
 }
 
+bool DxilLoopUnroll::doFinalization() {
+  const char *Msg =
+      "Could not unroll loop. Loop bound could not be deduced at compile time. "
+      "Use [unroll(n)] to give an explicit count.";
+
+  if (LoopsThatFailed.size()) {
+    for (Loop *L : LoopsThatFailed) {
+      Function *F = L->getHeader()->getParent();
+      DebugLoc LoopLoc = L->getStartLoc(); // Debug location for the start of the loop.
+      if (OnlyWarnOnFail) {
+        FailLoopUnroll(true /*warn only*/, F, LoopLoc, Msg);
+      }
+      else {
+        FailLoopUnroll(false /*warn only*/, F, LoopLoc,
+          Twine(Msg) + Twine(" Use '-HV 2016' to treat this as warning."));
+      }
+    }
+  }
+  return false;
+}
+
 }
 
 Pass *llvm::createDxilLoopUnrollPass(unsigned MaxIterationAttempt, bool OnlyWarnOnFail, bool StructurizeLoopExits) {

+ 31 - 0
tools/clang/test/HLSLFileCheck/hlsl/control_flow/loops/nested_unroll.hlsl

@@ -0,0 +1,31 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -E main -T ps_6_0 %s -Od | FileCheck %s
+
+// CHECK: @main
+
+// i == 0
+// CHECK: call float @dx.op.unary.f32(i32 12
+
+// i == 1
+// CHECK: call float @dx.op.unary.f32(i32 12
+// CHECK: call float @dx.op.unary.f32(i32 12
+
+// i == 2
+// CHECK: call float @dx.op.unary.f32(i32 12
+// CHECK: call float @dx.op.unary.f32(i32 12
+
+// CHECK-NOT: call float @dx.op.unary.f32(i32 13
+
+[RootSignature("")]
+float main(float foo : FOO) : SV_Target {
+  float result = 0;
+  [unroll]
+  for (uint i = 0; i < 3; i++) {
+    [unroll]
+    for (uint j = 0; j <= i; j++) {
+      result += cos(j + i + foo);
+    }
+  }
+  return result;
+}
+

+ 44 - 0
tools/clang/test/HLSLFileCheck/hlsl/control_flow/loops/nested_unroll2.hlsl

@@ -0,0 +1,44 @@
+// RUN: %dxc -E main -T ps_6_0 %s -Od | FileCheck %s
+
+// CHECK: @main
+
+// i == 0 && j == 0
+// CHECK: call float @dx.op.unary.f32(i32 12
+
+// i == 1 && j == 0
+// CHECK: call float @dx.op.unary.f32(i32 12
+
+// i == 1 && j == 1
+// CHECK: call float @dx.op.unary.f32(i32 12
+// CHECK: call float @dx.op.unary.f32(i32 12
+
+// i == 2 && j == 0
+// CHECK: call float @dx.op.unary.f32(i32 12
+
+// i == 2 && j == 1
+// CHECK: call float @dx.op.unary.f32(i32 12
+// CHECK: call float @dx.op.unary.f32(i32 12
+
+// i == 2 && j == 2
+// CHECK: call float @dx.op.unary.f32(i32 12
+// CHECK: call float @dx.op.unary.f32(i32 12
+// CHECK: call float @dx.op.unary.f32(i32 12
+
+// CHECK-NOT: call float @dx.op.unary.f32(i32 13
+
+[RootSignature("")]
+float main(float foo : FOO, float bar : BAR) : SV_Target {
+  float result = 0;
+  [unroll]
+  for (uint i = 0; i < 3; i++) {
+    [unroll]
+    for (uint j = 0; j <= i; j++) {
+      [unroll]
+      for (uint k = 0; k <= j; k++) {
+        result += cos((k + j + i + foo) * bar);
+      }
+    }
+  }
+  return result;
+}
+

+ 25 - 0
tools/clang/test/HLSLFileCheck/hlsl/control_flow/loops/nested_unroll_fail.hlsl

@@ -0,0 +1,25 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: error: Could not unroll loop.
+
+[RootSignature("")]
+float main(float foo : FOO) : SV_Target {
+  float result = 0;
+  [unroll]
+  for (uint k = 0; k < 3; k++) {
+
+    [loop]
+    for (uint i = 0; i < 3; i++) {
+
+      // This will fail, since the middle loop is not
+      // unrolled.
+      [unroll]
+      for (uint j = 0; j <= i; j++) {
+        result += cos(k + j + i + foo);
+      }
+    }
+  }
+
+  return result;
+}
+