Browse Source

Made exit values not have to dominate latch in structurize loop exit (#3220)

Adam Yang 5 years ago
parent
commit
4b975c13f0

+ 286 - 197
lib/Transforms/Scalar/DxilRemoveUnstructuredLoopExits.cpp

@@ -111,10 +111,6 @@
 //     but the code will execute for one iteration even if the exit condition
 //     is met.
 //
-//   - If there are values used by exit_code that isn't defined in the 
-//     loop header (or anywhere that doesn't dominate the loop latch)
-//     this transformation does not take place.
-//
 // These limitations can be fixed in the future as needed.
 //
 //===----------------------------------------------------------------------===//
@@ -139,111 +135,81 @@
 
 using namespace llvm;
 
-static bool IsNoop(Instruction *inst) {
-  if (CallInst *ci = dyn_cast<CallInst>(inst)) {
-    if (Function *f = ci->getCalledFunction()) {
-      return f->getName() == hlsl::kNoopName;
-    }
-  }
-  return false;
-}
+static bool IsNoop(Instruction *inst);
 
-static BasicBlock *GetExitBlockForExitingBlock(Loop *L, BasicBlock *exiting_block) {
-  BranchInst *br = dyn_cast<BranchInst>(exiting_block->getTerminator());
-  assert(L->contains(exiting_block));
-  assert(br->isConditional());
-  BasicBlock *result = L->contains(br->getSuccessor(0)) ? br->getSuccessor(1) : br->getSuccessor(0);
-  assert(!L->contains(result));
-  return result;
-}
-
-static bool RemoveUnstructuredLoopExitsIteration(BasicBlock *exiting_block, Loop *L, LoopInfo *LI, DominatorTree *DT) {
+namespace {
 
-  LLVMContext &ctx = L->getHeader()->getContext();
-  Type *i1Ty = Type::getInt1Ty(ctx);
+struct Value_Info { // One value that flows out of the loop through the exiting block being removed.
+  Value *val, *false_val; // val: the value to propagate; false_val: placeholder used on edges that do not carry val (false for the exit condition, undef otherwise).
+  PHINode *exit_phi; // Phi in the exit block that receives val; nullptr when the exit condition has no phi there.
+};
 
-  BasicBlock *exit_block = GetExitBlockForExitingBlock(L, exiting_block);
+struct Propagator {
+  DenseMap<std::pair<BasicBlock *, Value *>, PHINode *> cached_phis;
+  std::unordered_set<BasicBlock *> seen;
 
-  // If there's more than one predecessors for this exit block, don't risk it.
-  if (!exit_block->getSinglePredecessor())
-    return false;
+  // Get the propagated phi for val in bb (safe to use in bb), or nullptr if none was created.
+  Value *Get(Value *val, BasicBlock *bb) {
+    auto it = cached_phis.find({ bb, val });
+    if (it == cached_phis.end())
+      return nullptr;
 
-  {
-    BasicBlock *latch = L->getLoopLatch();
-    BasicBlock *latch_exit = GetExitBlockForExitingBlock(L, latch);
+    return it->second;
+  }
 
-    // If the latch and the exiting block go to the same place, then we probably already fixed this exit.
-    if (exit_block == latch_exit) {
-      return false;
+  void DeleteAllNewValues() {
+    for (auto &pair : cached_phis) {
+      pair.second->dropAllReferences();
     }
-
-    for (Instruction &I : *exit_block) {
-      if (PHINode *phi = dyn_cast<PHINode>(&I)) {
-        // If there are values flowing out of the loop into the exit_block,
-        // if any of those values do not dominate the latch, they would need
-        // to be propagated to the latch, which we don't do right now.
-        //
-        if (Instruction *value = dyn_cast<Instruction>(phi->getIncomingValueForBlock(exiting_block))) {
-          if (!DT->dominates(value, latch)) {
-            return false;
-          }
-        }
-      }
-      else {
-        break;
-      }
+    for (auto &pair : cached_phis) {
+      pair.second->eraseFromParent();
     }
+    cached_phis.clear();
   }
 
-  BranchInst *exiting_br = cast<BranchInst>(exiting_block->getTerminator());
-  Value *exit_cond = exiting_br->getCondition();
-
-  Value *exit_cond_dominates_latch = nullptr;
-  BasicBlock *new_exiting_block = nullptr;
-  SmallVector<std::pair<BasicBlock *, Value *>, 4> blocks_with_side_effect;
-  bool give_up = false;
-  std::unordered_map<BasicBlock *, PHINode *> cached_phis;
+  BasicBlock *Run(std::vector<Value_Info> &exit_values, BasicBlock *exiting_block, BasicBlock *latch, DominatorTree *DT, Loop *L, LoopInfo *LI, std::vector<BasicBlock *> &blocks_with_side_effect) {
+    BasicBlock *ret = RunImpl(exit_values, exiting_block, latch, DT, L, LI, blocks_with_side_effect);
+    // If we failed, remove all the values we added.
+    if (!ret) {
+      DeleteAllNewValues();
+    }
+    return ret;
+  }
 
-  // Use a worklist to propagate the exit condition from within the block
-  {
-    Value *false_value = ConstantInt::getFalse(i1Ty);
+  BasicBlock *RunImpl(std::vector<Value_Info> &exit_values, BasicBlock *exiting_block, BasicBlock *latch, DominatorTree *DT, Loop *L, LoopInfo *LI, std::vector<BasicBlock *> &blocks_with_side_effect) {
 
-    struct Propagate_Data {
+    struct Edge {
+      BasicBlock *prev;
       BasicBlock *bb;
-      Value *exit_cond;
     };
 
-    std::unordered_set<BasicBlock *> seen;
-    SmallVector<Propagate_Data, 4> work_list;
-
-    work_list.push_back({ exiting_block, exit_cond, });
+    BasicBlock *new_exiting_block = nullptr;
+    SmallVector<Edge, 4> work_list;
+    work_list.push_back({ nullptr, exiting_block });
     seen.insert(exiting_block);
 
-    BasicBlock *latch = L->getLoopLatch();
-
     for (unsigned i = 0; i < work_list.size(); i++) {
-      Propagate_Data data = work_list[i];
+      auto &edge = work_list[i];
+      BasicBlock *prev = edge.prev;
+      BasicBlock *bb = edge.bb;
 
-      BasicBlock *bb = data.bb;
-
-      // Don't continue to propagate when we hit the latch
-      if (bb == latch && DT->dominates(bb, latch)) {
-        exit_cond_dominates_latch = data.exit_cond;
+      // Stop propagating once we reach the latch or any block that dominates it.
+      if (DT->dominates(bb, latch)) {
         new_exiting_block = bb;
         continue;
       }
 
       // Do not include the exiting block itself in this calculation
-      if (i != 0) {
+      if (prev != nullptr) {
         // If this block is part of an inner loop... Give up for now.
-        if (LI->getLoopFor(data.bb) != L) {
-          give_up = true;
+        if (LI->getLoopFor(bb) != L) {
+          return nullptr;
         }
         // Otherwise just remember the blocks with side effects (including the latch)
         else {
           for (Instruction &I : *bb) {
             if (I.mayReadOrWriteMemory() && !IsNoop(&I)) {
-              blocks_with_side_effect.push_back({ bb, data.exit_cond });
+              blocks_with_side_effect.push_back(bb);
               break;
             }
           }
@@ -251,173 +217,296 @@ static bool RemoveUnstructuredLoopExitsIteration(BasicBlock *exiting_block, Loop
       } // If this is not the first iteration
 
       for (BasicBlock *succ : llvm::successors(bb)) {
+        // Don't propagate if block is not part of this loop.
         if (!L->contains(succ))
           continue;
 
-        PHINode *phi = cached_phis[succ];
-        if (!phi) {
-          phi = PHINode::Create(i1Ty, 2, "dx.struct_exit.exit_cond", &*succ->begin());
-          for (BasicBlock *pred : llvm::predecessors(succ)) {
-            phi->addIncoming(false_value, pred);
+        for (auto &pair : exit_values) {
+          // Find or create phi for the value in the successor block
+          PHINode *phi = cached_phis[{ succ, pair.val }];
+          if (!phi) {
+            phi = PHINode::Create(pair.false_val->getType(), 0, "dx.struct_exit.prop", &*succ->begin());
+            for (BasicBlock *pred : llvm::predecessors(succ)) {
+              phi->addIncoming(pair.false_val, pred);
+            }
+            cached_phis[{ succ, pair.val }] = phi;
           }
-          cached_phis[succ] = phi;
-        }
 
-        for (unsigned i = 0; i < phi->getNumIncomingValues(); i++) {
-          if (phi->getIncomingBlock(i) == bb) {
-            phi->setIncomingValue(i, data.exit_cond);
-            break;
+          // Find the incoming value for successor block
+          Value *incoming = nullptr;
+          if (!prev) {
+            incoming = pair.val;
+          }
+          else {
+            incoming = cached_phis[{ bb, pair.val }];
           }
-        }
 
-        if (!seen.count(succ)) {
-          work_list.push_back({ succ, phi, });
-          seen.insert(succ);
-        }
+          // Set incoming value for our phi
+          for (unsigned i = 0; i < phi->getNumIncomingValues(); i++) {
+            if (phi->getIncomingBlock(i) == bb) {
+              phi->setIncomingValue(i, incoming);
+            }
+          }
 
+          // Add to worklist
+          if (!seen.count(succ)) {
+            work_list.push_back({ bb, succ });
+            seen.insert(succ);
+          }
+        }
       } // for each succ
     } // for each in worklist
-  } // if exit condition is an instruction
 
-  if (give_up) {
-    for (std::pair<BasicBlock *, PHINode *> pair : cached_phis) {
-      if (pair.second)
-        pair.second->dropAllReferences();
+    if (new_exiting_block == exiting_block) {
+      return nullptr;
     }
-    for (std::pair<BasicBlock *, PHINode *> pair : cached_phis) {
-      if (pair.second)
-        pair.second->eraseFromParent();
+
+    return new_exiting_block;
+  }
+}; // struct Propagator
+
+} // Unnamed namespace
+
+static bool IsNoop(Instruction *inst) { // True only when inst is a direct call to the function named hlsl::kNoopName.
+  if (CallInst *ci = dyn_cast<CallInst>(inst)) {
+    if (Function *f = ci->getCalledFunction()) { // No callee function is available here: falls through to false.
+      return f->getName() == hlsl::kNoopName;
     }
-    return false;
   }
+  return false;
+}
 
-  // Make the exiting block not exit.
-  {
-    BasicBlock *non_exiting_block = exiting_br->getSuccessor(exiting_br->getSuccessor(0) == exit_block ? 1 : 0);
-    BranchInst::Create(non_exiting_block, exiting_block);
-    exiting_br->eraseFromParent();
-    exiting_br = nullptr;
+static BasicBlock *GetExitBlockForExitingBlock(Loop *L, BasicBlock *exiting_block) { // Returns the successor of exiting_block's conditional branch that lies outside loop L.
+  BranchInst *br = dyn_cast<BranchInst>(exiting_block->getTerminator()); // NOTE(review): dyn_cast yields null for a non-branch terminator and br is dereferenced below without a null check.
+  assert(L->contains(exiting_block));
+  assert(br->isConditional());
+  BasicBlock *result = L->contains(br->getSuccessor(0)) ? br->getSuccessor(1) : br->getSuccessor(0); // Pick whichever successor is not the in-loop one.
+  assert(!L->contains(result)); // Fires if both successors are inside the loop.
+  return result;
+}
+
+// Branch over the block's content when cond is true: bb is split into bb / body / end, and bb jumps to end if cond holds, otherwise to body.
+// All values defined in body and used outside it are replaced by a phi in end (undef on the edge that skips body).
+//
+static void SkipBlockWithBranch(BasicBlock *bb, Value *cond, Loop *L, LoopInfo *LI) {
+  BasicBlock *body = bb->splitBasicBlock(bb->getFirstNonPHI()); // bb keeps only its phis; everything else moves to body.
+  body->setName("dx.struct_exit.cond_body");
+  BasicBlock *end = body->splitBasicBlock(body->getTerminator()); // end receives body's original terminator.
+  end->setName("dx.struct_exit.cond_end");
+
+  bb->getTerminator()->eraseFromParent(); // Drop the unconditional branch left by the split.
+  BranchInst::Create(end, body, cond, bb); // cond true -> end (skip), cond false -> body.
+
+  for (Instruction &inst : *body) {
+    PHINode *phi = nullptr; // Created lazily, at most one per instruction.
+
+    for (User *user : inst.users()) {
+      Instruction *user_inst = dyn_cast<Instruction>(user);
+      if (!user_inst)
+        continue;
+
+      if (user_inst->getParent() != body) {
+        if (!phi) {
+          phi = PHINode::Create(inst.getType(), 2, "", &*end->begin());
+          phi->addIncoming(UndefValue::get(inst.getType()), bb); // Skip edge: the value was never computed.
+          phi->addIncoming(&inst, body);
+        }
+
+        user_inst->replaceUsesOfWith(&inst, phi); // NOTE(review): this edits inst's use list while inst.users() is being iterated - confirm this is safe.
+      }
+    } // For each user of inst of body
+  } // For each inst in body
+
+  L->addBasicBlockToLoop(body, *LI); // Register both new blocks with loop L in LoopInfo.
+  L->addBasicBlockToLoop(end, *LI);
+}
+
+static unsigned GetNumPredecessors(BasicBlock *bb) { // Counts the entries yielded by llvm::predecessors(bb).
+  unsigned ret = 0;
+  for (BasicBlock *pred : llvm::predecessors(bb)) {
+    (void)pred; // Only the count matters; silences the unused-variable warning.
+    ret++;
   }
+  return ret;
+}
 
-  // If bb has side effect, split it into 3 basic blocks, where its body is
-  // gated behind if (!exit_cond)
-  for (std::pair<BasicBlock *, Value *> data : blocks_with_side_effect) {
-    BasicBlock *bb = data.first;
-    Value *exit_cond = data.second;
+static bool RemoveUnstructuredLoopExitsIteration(BasicBlock *exiting_block, Loop *L, LoopInfo *LI, DominatorTree *DT) {
 
-    BasicBlock *body = bb->splitBasicBlock(bb->getFirstNonPHI());
-    body->setName("dx.struct_exit.cond_body");
-    BasicBlock *end = body->splitBasicBlock(body->getTerminator());
-    end->setName("dx.struct_exit.cond_end");
+  LLVMContext &ctx = L->getHeader()->getContext();
+  Type *i1Ty = Type::getInt1Ty(ctx);
 
-    bb->getTerminator()->eraseFromParent();
-    BranchInst::Create(end, body, exit_cond, bb);
+  BasicBlock *exit_block = GetExitBlockForExitingBlock(L, exiting_block);
 
-    for (Instruction &inst : *body) {
-      PHINode *phi = nullptr;
+  BasicBlock *latch = L->getLoopLatch();
+  BasicBlock *latch_exit = GetExitBlockForExitingBlock(L, latch);
 
-      for (User *user : inst.users()) {
-        Instruction *user_inst = dyn_cast<Instruction>(user);
-        if (!user_inst)
-          continue;
+  // If exiting block already dominates latch, then no need to do anything.
+  if (DT->dominates(exiting_block, latch)) {
+    return false;
+  }
 
-        if (user_inst->getParent() != body) {
-          if (!phi) {
-            phi = PHINode::Create(inst.getType(), 2, "", &*end->begin());
-            phi->addIncoming(UndefValue::get(inst.getType()), bb);
-            phi->addIncoming(&inst, body);
-          }
+  Propagator prop;
 
-          user_inst->replaceUsesOfWith(&inst, phi);
+  BranchInst *exiting_br = cast<BranchInst>(exiting_block->getTerminator());
+  Value *exit_cond = exiting_br->getCondition();
+  BasicBlock *new_exiting_block = nullptr;
+
+  std::vector<Value_Info> exit_values;
+  std::vector<BasicBlock *> blocks_with_side_effect;
+
+  // Find the values that flow into the exit block from this loop.
+  {
+    // Look at the lcssa phi's in the exit block.
+    bool exit_cond_has_phi = false;
+    for (Instruction &I : *exit_block) {
+      if (PHINode *phi = dyn_cast<PHINode>(&I)) {
+        // If there are values flowing out of the loop into the exit_block,
+        // add them to the list to be propagated
+        Value *value = phi->getIncomingValueForBlock(exiting_block);
+        Value *false_value = nullptr;
+        if (value == exit_cond) {
+          false_value = ConstantInt::getFalse(i1Ty);
+          exit_cond_has_phi = true;
+        }
+        else {
+          false_value = UndefValue::get(value->getType());
         }
-      } // For each user of inst of body
-    } // For each inst in body
+        exit_values.push_back({ value, false_value, phi });
+      }
+      else {
+        break;
+      }
+    }
 
-    L->addBasicBlockToLoop(body, *LI);
-    L->addBasicBlockToLoop(end, *LI);
+    // If the exit condition is not among the exit phi's, add it.
+    if (!exit_cond_has_phi) {
+      exit_values.push_back({ exit_cond, ConstantInt::getFalse(i1Ty), nullptr });
+    }
+  }
 
-  } // For each bb with side effect
+  //
+  // Propagate those values we just found to a block that dominates the latch
+  //
+  new_exiting_block = prop.Run(exit_values, exiting_block, latch, DT, L, LI, blocks_with_side_effect);
 
-  assert(exit_cond_dominates_latch);
-  assert(new_exiting_block);
+  // Stop now if we failed
+  if (!new_exiting_block)
+    return false;
+
+  // Guard the body of every block with side effects so it is skipped once the propagated exit condition is true.
+  for (BasicBlock *bb : blocks_with_side_effect) {
+    Value *exit_cond_for_block = prop.Get(exit_cond, bb);
+    SkipBlockWithBranch(bb, exit_cond_for_block, L, LI);
+  }
+
+  // Make the exiting block not exit.
+  {
+    BasicBlock *non_exiting_block = exiting_br->getSuccessor(exiting_br->getSuccessor(0) == exit_block ? 1 : 0);
+    BranchInst::Create(non_exiting_block, exiting_block);
+    exiting_br->eraseFromParent();
+    exiting_br = nullptr;
+  }
+
+  Value *new_exit_cond = prop.Get(exit_cond, new_exiting_block);
+  assert(new_exit_cond);
 
   // Split the block where we're now exiting from, and branch to latch exit
-  BasicBlock *latch_exit = GetExitBlockForExitingBlock(L, L->getLoopLatch());
   StringRef old_name = new_exiting_block->getName();
   BasicBlock *new_not_exiting_block = new_exiting_block->splitBasicBlock(new_exiting_block->getFirstNonPHI());
   new_exiting_block->setName("dx.struct_exit.new_exiting");
   new_not_exiting_block->setName(old_name);
   L->addBasicBlockToLoop(new_not_exiting_block, *LI);
 
+  // Branch to latch_exit
   new_exiting_block->getTerminator()->eraseFromParent();
-  BranchInst::Create(latch_exit, new_not_exiting_block, exit_cond_dominates_latch, new_exiting_block);
-
-  // Split the latch exit, since it's going to branch to the real exit block
-  BasicBlock *post_exit_location = latch_exit->splitBasicBlock(latch_exit->getFirstNonPHI());
-  // If latch exit is part of an outer loop, add its split in there too.
-  if (Loop *outer_loop = LI->getLoopFor(latch_exit)) {
-    outer_loop->addBasicBlockToLoop(post_exit_location, *LI);
-  }
-  // If the original exit block is part of an outer loop, then latch exit (which is the
-  // new exit block) must be part of it, since all blocks that branch to within
-  // a loop must be part of that loop structure.
-  else if (Loop *outer_loop = LI->getLoopFor(exit_block)) {
-    outer_loop->addBasicBlockToLoop(latch_exit, *LI);
-  }
-
-  // Since now new exiting block is branching to latch exit, its phis need to be updated.
-  for (Instruction &inst : *latch_exit) {
-    PHINode *phi = dyn_cast<PHINode>(&inst);
-    if (!phi)
-      break;
-    phi->addIncoming(UndefValue::get(phi->getType()), new_exiting_block);
+  BranchInst::Create(latch_exit, new_not_exiting_block, new_exit_cond, new_exiting_block);
+
+  // If the exit block and the latch exit are the same, then we're already good.
+  // just update the phi nodes in the exit block.
+  if (latch_exit == exit_block) {
+    for (Value_Info &info : exit_values) {
+      // Retarget the exit block's phi: the edge from exiting_block now comes from new_exiting_block and carries the propagated value.
+      PHINode *exit_phi = info.exit_phi;
+      if (exit_phi) {
+        for (unsigned i = 0; i < exit_phi->getNumIncomingValues(); i++) {
+          if (exit_phi->getIncomingBlock(i) == exiting_block) {
+            exit_phi->setIncomingBlock(i, new_exiting_block);
+            exit_phi->setIncomingValue(i, prop.Get(info.val, new_exiting_block));
+          }
+        }
+      }
+    }
   }
+  // Otherwise...
+  else {
 
-  unsigned latch_exit_num_predecessors = 0;
-  for (BasicBlock *pred : llvm::predecessors(latch_exit)) {
-    (void)pred;
-    latch_exit_num_predecessors++;
-  }
+    // 1. Split the latch exit, since it's going to branch to the real exit block
+    BasicBlock *post_exit_location = latch_exit->splitBasicBlock(latch_exit->getFirstNonPHI());
 
-  // Make exit condition visible
-  PHINode *exit_cond_lcssa = PHINode::Create(exit_cond_dominates_latch->getType(), latch_exit_num_predecessors, "dx.struct_exit.exit_cond_lcssa", latch_exit->begin());
-  for (BasicBlock *pred : llvm::predecessors(latch_exit)) {
-    if (pred == new_exiting_block) {
-      exit_cond_lcssa->addIncoming(exit_cond_dominates_latch, pred);
+    {
+      // If latch exit is part of an outer loop, add its split in there too.
+      if (Loop *outer_loop = LI->getLoopFor(latch_exit)) {
+        outer_loop->addBasicBlockToLoop(post_exit_location, *LI);
+      }
+      // If the original exit block is part of an outer loop, then latch exit (which is the
+      // new exit block) must be part of it, since all blocks that branch to within
+      // a loop must be part of that loop structure.
+      else if (Loop *outer_loop = LI->getLoopFor(exit_block)) {
+        outer_loop->addBasicBlockToLoop(latch_exit, *LI);
+      }
     }
-    else {
-      exit_cond_lcssa->addIncoming(ConstantInt::getFalse(exit_cond_lcssa->getType()), pred);
+
+    // 2. Add incoming values to latch_exit's phi nodes.
+    // Since now new exiting block is branching to latch exit, its phis need to be updated.
+    for (Instruction &inst : *latch_exit) {
+      PHINode *phi = dyn_cast<PHINode>(&inst);
+      if (!phi)
+        break;
+      phi->addIncoming(UndefValue::get(phi->getType()), new_exiting_block);
     }
-  }
 
-  // Take the exit outside the loop.
-  latch_exit->getTerminator()->eraseFromParent();
-  BranchInst::Create(exit_block, post_exit_location, exit_cond_lcssa, latch_exit);
 
-  // Fix the phi's in the real exit block, and insert phis in the latch exit to maintain
-  // lcssa form.
-  for (Instruction &inst : *exit_block) {
-    PHINode *phi = dyn_cast<PHINode>(&inst);
-    if (!phi)
-      break;
+    unsigned latch_exit_num_predecessors = GetNumPredecessors(latch_exit);
+    PHINode *exit_cond_lcssa = nullptr;
+    for (Value_Info &info : exit_values) {
 
-    for (unsigned i = 0; i < phi->getNumIncomingValues(); i++) {
-      if (phi->getIncomingBlock(i) == exiting_block) {
-        phi->setIncomingBlock(i, latch_exit);
+      // 3. Create lcssa phi's for all the propagated values at latch_exit.
+      // Make exit values visible in the latch_exit
+      PHINode *val_lcssa = PHINode::Create(info.val->getType(), latch_exit_num_predecessors, "dx.struct_exit.val_lcssa", latch_exit->begin());
 
-        PHINode *lcssa_phi = PHINode::Create(phi->getType(), latch_exit_num_predecessors, "dx.struct_exit.lcssa_phi", latch_exit->begin());
-        for (BasicBlock *pred : llvm::predecessors(latch_exit)) {
-          if (pred == new_exiting_block) {
-            lcssa_phi->addIncoming(phi->getIncomingValue(i), new_exiting_block);
-          }
-          else {
-            lcssa_phi->addIncoming(UndefValue::get(lcssa_phi->getType()), pred);
-          }
+      if (info.val == exit_cond) {
+        // Record the phi for the exit condition
+        exit_cond_lcssa = val_lcssa;
+        exit_cond_lcssa->setName("dx.struct_exit.exit_cond_lcssa");
+      }
+
+      for (BasicBlock *pred : llvm::predecessors(latch_exit)) {
+        if (pred == new_exiting_block) {
+          Value *incoming = prop.Get(info.val, new_exiting_block);
+          assert(incoming);
+          val_lcssa->addIncoming(incoming, pred);
+        }
+        else {
+          val_lcssa->addIncoming(info.false_val, pred);
         }
+      }
 
-        phi->setIncomingValue(i, lcssa_phi);
+      // 4. Update the phis in the exit_block to use the lcssa phi's we just created.
+      PHINode *exit_phi = info.exit_phi;
+      if (exit_phi) {
+        for (unsigned i = 0; i < exit_phi->getNumIncomingValues(); i++) {
+          if (exit_phi->getIncomingBlock(i) == exiting_block) {
+            exit_phi->setIncomingBlock(i, latch_exit);
+            exit_phi->setIncomingValue(i, val_lcssa);
+          }
+        }
       }
     }
+
+    // 5. Take the first half of latch_exit and branch it to the exit_block based
+    // on the propagated exit condition.
+    latch_exit->getTerminator()->eraseFromParent();
+    BranchInst::Create(exit_block, post_exit_location, exit_cond_lcssa, latch_exit);
   }
 
   DT->recalculate(*L->getHeader()->getParent());

+ 15 - 8
tools/clang/test/HLSLFileCheck/hlsl/control_flow/loops/struct_exit_exit_value.hlsl

@@ -1,13 +1,22 @@
-// RUN: %dxc -Zi -E main -O3 -T ps_6_0 -opt-enable structurize-loop-exits-for-unroll %s | FileCheck %s
 // RUN: %dxc -Zi -E main -Od -T ps_6_0 -opt-enable structurize-loop-exits-for-unroll %s -DFORCE_UNROLL | FileCheck %s
 // RUN: %dxc -Zi -E main -T ps_6_0 -opt-enable structurize-loop-exits-for-unroll %s -DFORCE_UNROLL | FileCheck %s
 
+// Note: the path without [unroll] is not tested; we could not construct a loop small enough for the optimizer to unroll on its own.
+
 // CHECK: %{{.+}} = call float @dx.op.unary.f32(i32 13
+
+// CHECK: dx.struct_exit.cond_body
+// CHECK: call void @dx.op.textureStore.f32(i32 67
+
 // CHECK: %{{.+}} = call float @dx.op.unary.f32(i32 13
+
+// CHECK: dx.struct_exit.cond_body
+// CHECK: call void @dx.op.textureStore.f32(i32 67
+
 // CHECK: %{{.+}} = call float @dx.op.unary.f32(i32 13
 
-// Make sure we didn't transform
-// CHECK-NOT: dx.struct_exit
+// CHECK: dx.struct_exit.cond_body
+// CHECK: call void @dx.op.textureStore.f32(i32 67
 
 #ifdef FORCE_UNROLL
 #define UNROLL [unroll]
@@ -34,12 +43,9 @@ float main(uint a : A, uint b : B, uint c : C) : SV_Target {
 
       if ((a * i) & b) {
 
-        int offset = 0; // This value doesn't dominate latch, is loop dependent,
+        int offset = i; // This value doesn't dominate latch, is loop dependent,
                         // and therefore must be propagated through to loop latch
                         // so the hoisted loop exit can use it.
-                        //
-                        // We don't do this right now, so the transformation shouldn't
-                        // happen.
         if (i % 2 == 0) {
           offset = 1;
         }
@@ -50,8 +56,9 @@ float main(uint a : A, uint b : B, uint c : C) : SV_Target {
           return 1;
         }
 
-        array[(idx + i) % 5] += a;
+        uav1[i + offset] += a;
       }
+      uav1[i] += a;
     }
   }