DxilRemoveUnstructuredLoopExits.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. //===- DxilRemoveUnstructuredLoopExits.cpp - Make unrolled loops structured ---===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //===----------------------------------------------------------------------===//
  10. //
// Loops that look like the following become unstructured when unrolled:
  12. //
  13. // for(;;) {
  14. // if (a) {
  15. // if (b) {
  16. // exit_code_0;
  17. // break; // Unstructured loop exit
  18. // }
  19. //
  20. // code_0;
  21. //
  22. // if (c) {
  23. // if (d) {
  24. // exit_code_1;
  25. // break; // Unstructured loop exit
  26. // }
  27. // code_1;
  28. // }
  29. //
  30. // code_2;
  31. //
  32. // ...
  33. // }
  34. //
  35. // code_3;
  36. //
  37. // if (exit)
  38. // break;
  39. // }
  40. //
  41. //
  42. // This pass transforms the loop into the following form:
  43. //
  44. // bool broke_0 = false;
  45. // bool broke_1 = false;
  46. //
  47. // for(;;) {
  48. // if (a) {
  49. // if (b) {
  50. // broke_0 = true; // Break flag
  51. // }
  52. //
  53. // if (!broke_0) {
  54. // code_0;
  55. // }
  56. //
  57. // if (!broke_0) {
  58. // if (c) {
  59. // if (d) {
  60. // broke_1 = true; // Break flag
  61. // }
  62. // if (!broke_1) {
  63. // code_1;
  64. // }
  65. // }
  66. //
  67. // if (!broke_1) {
  68. // code_2;
  69. // }
  70. // }
  71. //
  72. // ...
  73. // }
  74. //
// if (broke_0) {
// break;
// }
//
// if (broke_1) {
// break;
// }
  82. //
  83. // code_3;
  84. //
  85. // if (exit)
  86. // break;
  87. // }
  88. //
  89. // if (broke_0) {
  90. // exit_code_0;
  91. // }
  92. //
  93. // if (broke_1) {
  94. // exit_code_1;
  95. // }
  96. //
  97. // Essentially it hoists the exit branch out of the loop.
  98. //
// This function should be called any time before a loop is unrolled to
// avoid generating unstructured code.
  101. //
  102. // There are several limitations at the moment:
  103. //
// - If code_0, code_1, etc. contain any loops, this transform
// does not take place. Since the values that flow out of the conditions
// are phis of undef, I do not want to risk the loops not exiting.
  107. //
  108. // - code_0, code_1, etc, become conditional only when there are
  109. // side effects in there. This doesn't impact code correctness,
  110. // but the code will execute for one iteration even if the exit condition
  111. // is met.
  112. //
  113. // These limitations can be fixed in the future as needed.
  114. //
  115. //===----------------------------------------------------------------------===//
  116. #include "llvm/Analysis/LoopPass.h"
  117. #include "llvm/Analysis/AssumptionCache.h"
  118. #include "llvm/Transforms/Scalar.h"
  119. #include "llvm/Transforms/Utils/Local.h"
  120. #include "llvm/Transforms/Utils/LoopUtils.h"
  121. #include "llvm/IR/Instructions.h"
  122. #include "llvm/IR/Verifier.h"
  123. #include "llvm/IR/IntrinsicInst.h"
  124. #include "llvm/Support/raw_ostream.h"
  125. #include "llvm/Support/Debug.h"
  126. #include "llvm/ADT/SetVector.h"
  127. #include "dxc/HLSL/DxilNoops.h"
  128. #include <unordered_map>
  129. #include <unordered_set>
  130. #include "DxilRemoveUnstructuredLoopExits.h"
  131. using namespace llvm;
  132. static bool IsNoop(Instruction *inst);
  133. namespace {
  134. struct Value_Info {
  135. Value *val, *false_val;
  136. PHINode *exit_phi;
  137. };
  138. struct Propagator {
  139. DenseMap<std::pair<BasicBlock *, Value *>, PHINode *> cached_phis;
  140. std::unordered_set<BasicBlock *> seen;
  141. // Get propagated value for val. It's guaranteed to be safe to use in bb.
  142. Value *Get(Value *val, BasicBlock *bb) {
  143. auto it = cached_phis.find({ bb, val });
  144. if (it == cached_phis.end())
  145. return nullptr;
  146. return it->second;
  147. }
  148. void DeleteAllNewValues() {
  149. for (auto &pair : cached_phis) {
  150. pair.second->dropAllReferences();
  151. }
  152. for (auto &pair : cached_phis) {
  153. pair.second->eraseFromParent();
  154. }
  155. cached_phis.clear();
  156. }
  157. BasicBlock *Run(std::vector<Value_Info> &exit_values, BasicBlock *exiting_block, BasicBlock *latch, DominatorTree *DT, Loop *L, LoopInfo *LI, std::vector<BasicBlock *> &blocks_with_side_effect) {
  158. BasicBlock *ret = RunImpl(exit_values, exiting_block, latch, DT, L, LI, blocks_with_side_effect);
  159. // If we failed, remove all the values we added.
  160. if (!ret) {
  161. DeleteAllNewValues();
  162. }
  163. return ret;
  164. }
  165. BasicBlock *RunImpl(std::vector<Value_Info> &exit_values, BasicBlock *exiting_block, BasicBlock *latch, DominatorTree *DT, Loop *L, LoopInfo *LI, std::vector<BasicBlock *> &blocks_with_side_effect) {
  166. struct Edge {
  167. BasicBlock *prev;
  168. BasicBlock *bb;
  169. };
  170. BasicBlock *new_exiting_block = nullptr;
  171. SmallVector<Edge, 4> work_list;
  172. work_list.push_back({ nullptr, exiting_block });
  173. seen.insert(exiting_block);
  174. for (unsigned i = 0; i < work_list.size(); i++) {
  175. auto &edge = work_list[i];
  176. BasicBlock *prev = edge.prev;
  177. BasicBlock *bb = edge.bb;
  178. // Don't continue to propagate when we hit the latch or dominate it.
  179. if (DT->dominates(bb, latch)) {
  180. new_exiting_block = bb;
  181. continue;
  182. }
  183. // Do not include the exiting block itself in this calculation
  184. if (prev != nullptr) {
  185. // If this block is part of an inner loop... Give up for now.
  186. if (LI->getLoopFor(bb) != L) {
  187. return nullptr;
  188. }
  189. // Otherwise just remember the blocks with side effects (including the latch)
  190. else {
  191. for (Instruction &I : *bb) {
  192. if (I.mayReadOrWriteMemory() && !IsNoop(&I)) {
  193. blocks_with_side_effect.push_back(bb);
  194. break;
  195. }
  196. }
  197. }
  198. } // If this is not the first iteration
  199. for (BasicBlock *succ : llvm::successors(bb)) {
  200. // Don't propagate if block is not part of this loop.
  201. if (!L->contains(succ))
  202. continue;
  203. for (auto &pair : exit_values) {
  204. // Find or create phi for the value in the successor block
  205. PHINode *phi = cached_phis[{ succ, pair.val }];
  206. if (!phi) {
  207. phi = PHINode::Create(pair.false_val->getType(), 0, "dx.struct_exit.prop", &*succ->begin());
  208. for (BasicBlock *pred : llvm::predecessors(succ)) {
  209. phi->addIncoming(pair.false_val, pred);
  210. }
  211. cached_phis[{ succ, pair.val }] = phi;
  212. }
  213. // Find the incoming value for successor block
  214. Value *incoming = nullptr;
  215. if (!prev) {
  216. incoming = pair.val;
  217. }
  218. else {
  219. incoming = cached_phis[{ bb, pair.val }];
  220. }
  221. // Set incoming value for our phi
  222. for (unsigned i = 0; i < phi->getNumIncomingValues(); i++) {
  223. if (phi->getIncomingBlock(i) == bb) {
  224. phi->setIncomingValue(i, incoming);
  225. }
  226. }
  227. // Add to worklist
  228. if (!seen.count(succ)) {
  229. work_list.push_back({ bb, succ });
  230. seen.insert(succ);
  231. }
  232. }
  233. } // for each succ
  234. } // for each in worklist
  235. if (new_exiting_block == exiting_block) {
  236. return nullptr;
  237. }
  238. return new_exiting_block;
  239. }
  240. }; // struct Propagator
  241. } // Unnamed namespace
  242. static bool IsNoop(Instruction *inst) {
  243. if (CallInst *ci = dyn_cast<CallInst>(inst)) {
  244. if (Function *f = ci->getCalledFunction()) {
  245. return f->getName() == hlsl::kNoopName;
  246. }
  247. }
  248. return false;
  249. }
  250. static Value* GetDefaultValue(Type *type) {
  251. if (type->isIntegerTy()) {
  252. return ConstantInt::get(type, 0);
  253. }
  254. else if (type->isFloatingPointTy()) {
  255. return ConstantFP::get(type, 0);
  256. }
  257. return UndefValue::get(type);
  258. }
  259. static BasicBlock *GetExitBlockForExitingBlock(Loop *L, BasicBlock *exiting_block) {
  260. BranchInst *br = dyn_cast<BranchInst>(exiting_block->getTerminator());
  261. assert(L->contains(exiting_block));
  262. assert(br->isConditional());
  263. BasicBlock *result = L->contains(br->getSuccessor(0)) ? br->getSuccessor(1) : br->getSuccessor(0);
  264. assert(!L->contains(result));
  265. return result;
  266. }
  267. // Branch over the block's content with the condition cond.
  268. // All values used outside the block is replaced by a phi.
  269. //
  270. static void SkipBlockWithBranch(BasicBlock *bb, Value *cond, Loop *L, LoopInfo *LI) {
  271. BasicBlock *body = bb->splitBasicBlock(bb->getFirstNonPHI());
  272. body->setName("dx.struct_exit.cond_body");
  273. BasicBlock *end = body->splitBasicBlock(body->getTerminator());
  274. end->setName("dx.struct_exit.cond_end");
  275. bb->getTerminator()->eraseFromParent();
  276. BranchInst::Create(end, body, cond, bb);
  277. for (Instruction &inst : *body) {
  278. PHINode *phi = nullptr;
  279. // For each user that's outside of 'body', replace its use of 'inst' with a phi created
  280. // in 'end'
  281. for (auto it = inst.user_begin(); it != inst.user_end();) {
  282. Instruction *user_inst = cast<Instruction>(*(it++));
  283. if (user_inst == phi)
  284. continue;
  285. if (user_inst->getParent() != body) {
  286. if (!phi) {
  287. phi = PHINode::Create(inst.getType(), 2, "", &*end->begin());
  288. phi->addIncoming(GetDefaultValue(inst.getType()), bb);
  289. phi->addIncoming(&inst, body);
  290. }
  291. user_inst->replaceUsesOfWith(&inst, phi);
  292. }
  293. } // For each user of inst of body
  294. } // For each inst in body
  295. L->addBasicBlockToLoop(body, *LI);
  296. L->addBasicBlockToLoop(end, *LI);
  297. }
  298. static unsigned GetNumPredecessors(BasicBlock *bb) {
  299. unsigned ret = 0;
  300. for (BasicBlock *pred : llvm::predecessors(bb)) {
  301. (void)pred;
  302. ret++;
  303. }
  304. return ret;
  305. }
  306. static bool RemoveUnstructuredLoopExitsIteration(BasicBlock *exiting_block, Loop *L, LoopInfo *LI, DominatorTree *DT) {
  307. LLVMContext &ctx = L->getHeader()->getContext();
  308. Type *i1Ty = Type::getInt1Ty(ctx);
  309. BasicBlock *exit_block = GetExitBlockForExitingBlock(L, exiting_block);
  310. BasicBlock *latch = L->getLoopLatch();
  311. BasicBlock *latch_exit = GetExitBlockForExitingBlock(L, latch);
  312. // If exiting block already dominates latch, then no need to do anything.
  313. if (DT->dominates(exiting_block, latch)) {
  314. return false;
  315. }
  316. Propagator prop;
  317. BranchInst *exiting_br = cast<BranchInst>(exiting_block->getTerminator());
  318. Value *exit_cond = exiting_br->getCondition();
  319. // When exit_block is false block, use !exit_cond as exit_cond.
  320. if (exiting_br->getSuccessor(1) == exit_block) {
  321. IRBuilder<> B(exiting_br);
  322. exit_cond = B.CreateNot(exit_cond);
  323. }
  324. BasicBlock *new_exiting_block = nullptr;
  325. std::vector<Value_Info> exit_values;
  326. std::vector<BasicBlock *> blocks_with_side_effect;
  327. // Find the values that flow into the exit block from this loop.
  328. {
  329. // Look at the lcssa phi's in the exit block.
  330. bool exit_cond_has_phi = false;
  331. for (Instruction &I : *exit_block) {
  332. if (PHINode *phi = dyn_cast<PHINode>(&I)) {
  333. // If there are values flowing out of the loop into the exit_block,
  334. // add them to the list to be propagated
  335. Value *value = phi->getIncomingValueForBlock(exiting_block);
  336. Value *false_value = nullptr;
  337. if (value == exit_cond) {
  338. false_value = ConstantInt::getFalse(i1Ty);
  339. exit_cond_has_phi = true;
  340. }
  341. else {
  342. false_value = GetDefaultValue(value->getType());
  343. }
  344. exit_values.push_back({ value, false_value, phi });
  345. }
  346. else {
  347. break;
  348. }
  349. }
  350. // If the exit condition is not among the exit phi's, add it.
  351. if (!exit_cond_has_phi) {
  352. exit_values.push_back({ exit_cond, ConstantInt::getFalse(i1Ty), nullptr });
  353. }
  354. }
  355. //
  356. // Propagate those values we just found to a block that dominates the latch
  357. //
  358. new_exiting_block = prop.Run(exit_values, exiting_block, latch, DT, L, LI, blocks_with_side_effect);
  359. // Stop now if we failed
  360. if (!new_exiting_block)
  361. return false;
  362. // If there are any blocks with side effects,
  363. for (BasicBlock *bb : blocks_with_side_effect) {
  364. Value *exit_cond_for_block = prop.Get(exit_cond, bb);
  365. SkipBlockWithBranch(bb, exit_cond_for_block, L, LI);
  366. }
  367. // Make the exiting block not exit.
  368. {
  369. BasicBlock *non_exiting_block = exiting_br->getSuccessor(exiting_br->getSuccessor(0) == exit_block ? 1 : 0);
  370. BranchInst::Create(non_exiting_block, exiting_block);
  371. exiting_br->eraseFromParent();
  372. exiting_br = nullptr;
  373. }
  374. Value *new_exit_cond = prop.Get(exit_cond, new_exiting_block);
  375. assert(new_exit_cond);
  376. // Split the block where we're now exiting from, and branch to latch exit
  377. StringRef old_name = new_exiting_block->getName();
  378. BasicBlock *new_not_exiting_block = new_exiting_block->splitBasicBlock(new_exiting_block->getFirstNonPHI());
  379. new_exiting_block->setName("dx.struct_exit.new_exiting");
  380. new_not_exiting_block->setName(old_name);
  381. L->addBasicBlockToLoop(new_not_exiting_block, *LI);
  382. // Branch to latch_exit
  383. new_exiting_block->getTerminator()->eraseFromParent();
  384. BranchInst::Create(latch_exit, new_not_exiting_block, new_exit_cond, new_exiting_block);
  385. // If the exit block and the latch exit are the same, then we're already good.
  386. // just update the phi nodes in the exit block.
  387. if (latch_exit == exit_block) {
  388. for (Value_Info &info : exit_values) {
  389. // Take the phi node in the exit block and reset incoming block and value from latch_exit
  390. PHINode *exit_phi = info.exit_phi;
  391. if (exit_phi) {
  392. for (unsigned i = 0; i < exit_phi->getNumIncomingValues(); i++) {
  393. if (exit_phi->getIncomingBlock(i) == exiting_block) {
  394. exit_phi->setIncomingBlock(i, new_exiting_block);
  395. exit_phi->setIncomingValue(i, prop.Get(info.val, new_exiting_block));
  396. }
  397. }
  398. }
  399. }
  400. }
  401. // Otherwise...
  402. else {
  403. // 1. Split the latch exit, since it's going to branch to the real exit block
  404. BasicBlock *post_exit_location = latch_exit->splitBasicBlock(latch_exit->getFirstNonPHI());
  405. {
  406. // If latch exit is part of an outer loop, add its split in there too.
  407. if (Loop *outer_loop = LI->getLoopFor(latch_exit)) {
  408. outer_loop->addBasicBlockToLoop(post_exit_location, *LI);
  409. }
  410. // If the original exit block is part of an outer loop, then latch exit (which is the
  411. // new exit block) must be part of it, since all blocks that branch to within
  412. // a loop must be part of that loop structure.
  413. else if (Loop *outer_loop = LI->getLoopFor(exit_block)) {
  414. outer_loop->addBasicBlockToLoop(latch_exit, *LI);
  415. }
  416. }
  417. // 2. Add incoming values to latch_exit's phi nodes.
  418. // Since now new exiting block is branching to latch exit, its phis need to be updated.
  419. for (Instruction &inst : *latch_exit) {
  420. PHINode *phi = dyn_cast<PHINode>(&inst);
  421. if (!phi)
  422. break;
  423. phi->addIncoming(GetDefaultValue(phi->getType()), new_exiting_block);
  424. }
  425. unsigned latch_exit_num_predecessors = GetNumPredecessors(latch_exit);
  426. PHINode *exit_cond_lcssa = nullptr;
  427. for (Value_Info &info : exit_values) {
  428. // 3. Create lcssa phi's for all the propagated values at latch_exit.
  429. // Make exit values visible in the latch_exit
  430. PHINode *val_lcssa = PHINode::Create(info.val->getType(), latch_exit_num_predecessors, "dx.struct_exit.val_lcssa", latch_exit->begin());
  431. if (info.val == exit_cond) {
  432. // Record the phi for the exit condition
  433. exit_cond_lcssa = val_lcssa;
  434. exit_cond_lcssa->setName("dx.struct_exit.exit_cond_lcssa");
  435. }
  436. for (BasicBlock *pred : llvm::predecessors(latch_exit)) {
  437. if (pred == new_exiting_block) {
  438. Value *incoming = prop.Get(info.val, new_exiting_block);
  439. assert(incoming);
  440. val_lcssa->addIncoming(incoming, pred);
  441. }
  442. else {
  443. val_lcssa->addIncoming(info.false_val, pred);
  444. }
  445. }
  446. // 4. Update the phis in the exit_block to use the lcssa phi's we just created.
  447. PHINode *exit_phi = info.exit_phi;
  448. if (exit_phi) {
  449. for (unsigned i = 0; i < exit_phi->getNumIncomingValues(); i++) {
  450. if (exit_phi->getIncomingBlock(i) == exiting_block) {
  451. exit_phi->setIncomingBlock(i, latch_exit);
  452. exit_phi->setIncomingValue(i, val_lcssa);
  453. }
  454. }
  455. }
  456. }
  457. // 5. Take the first half of latch_exit and branch it to the exit_block based
  458. // on the propagated exit condition.
  459. latch_exit->getTerminator()->eraseFromParent();
  460. BranchInst::Create(exit_block, post_exit_location, exit_cond_lcssa, latch_exit);
  461. }
  462. DT->recalculate(*L->getHeader()->getParent());
  463. assert(L->isLCSSAForm(*DT));
  464. return true;
  465. }
  466. bool hlsl::RemoveUnstructuredLoopExits(llvm::Loop *L, llvm::LoopInfo *LI, llvm::DominatorTree *DT, std::unordered_set<llvm::BasicBlock *> *exclude_set) {
  467. bool changed = false;
  468. if (!L->isLCSSAForm(*DT))
  469. return false;
  470. // Give up if loop is not rotated somehow
  471. if (BasicBlock *latch = L->getLoopLatch()) {
  472. if (!cast<BranchInst>(latch->getTerminator())->isConditional())
  473. return false;
  474. }
  475. // Give up if there's not a single latch
  476. else {
  477. return false;
  478. }
  479. for (;;) {
  480. // Recompute exiting block every time, since they could change between
  481. // iterations
  482. llvm::SmallVector<BasicBlock *, 4> exiting_blocks;
  483. L->getExitingBlocks(exiting_blocks);
  484. bool local_changed = false;
  485. for (BasicBlock *exiting_block : exiting_blocks) {
  486. auto latch = L->getLoopLatch();
  487. if (latch == exiting_block)
  488. continue;
  489. if (exclude_set && exclude_set->count(GetExitBlockForExitingBlock(L, exiting_block)))
  490. continue;
  491. // As soon as we got a success, break and start a new iteration, since
  492. // exiting blocks could have changed.
  493. local_changed = RemoveUnstructuredLoopExitsIteration(exiting_block, L, LI, DT);
  494. if (local_changed) {
  495. break;
  496. }
  497. }
  498. changed |= local_changed;
  499. if (!local_changed) {
  500. break;
  501. }
  502. }
  503. return changed;
  504. }