DxilLoopUnroll.cpp 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065
  1. //===- DxilLoopUnroll.cpp - Special Unroll for Constant Values ------------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //===----------------------------------------------------------------------===//
  10. //
  11. // Special loop unroll routine for creating mandatory constant values and
  12. // loops that have exits.
  13. //
  14. // Overview of algorithm:
  15. //
  16. // 1. Identify a set of blocks to unroll.
  17. //
  18. // LLVM's concept of loop excludes exit blocks, which are blocks that no
  19. // longer have a path to the loop latch. However, some exit blocks in HLSL
  20. // also need to be unrolled. For example:
  21. //
  22. // [unroll]
  23. // for (uint i = 0; i < 4; i++)
  24. // {
  25. // if (...)
  26. // {
  27. // // This block here is an exit block, since it's
  28. // // guaranteed to exit the loop.
  29. // ...
  30. // a[i] = ...; // Indexing requires unroll.
  31. // return;
  32. // }
  33. // }
  34. //
  35. //
  36. // 2. Create LCSSA based on the new loop boundary.
  37. //
  38. // See LCSSA.cpp for more details. It creates trivial PHI nodes for any
  39. // outgoing values of the loop at the exit blocks, so when the loop body
  40. // gets cloned, the outgoing values can be added to those PHI nodes easily.
  41. //
  42. // We are using a modified LCSSA routine here because we are including some
  43. // of the original exit blocks in the unroll.
  44. //
  45. //
  46. // 3. Unroll the loop until we succeed.
  47. //
  48. // Unlike LLVM, we do not try to find a loop count before unrolling.
  49. // Instead, we unroll to find a constant terminal condition. Give up when we
  50. // fail to do so.
  51. //
  52. //
  53. //===----------------------------------------------------------------------===//
  54. #include "llvm/Pass.h"
  55. #include "llvm/Analysis/LoopPass.h"
  56. #include "llvm/Analysis/InstructionSimplify.h"
  57. #include "llvm/Analysis/AssumptionCache.h"
  58. #include "llvm/Analysis/LoopPass.h"
  59. #include "llvm/Analysis/InstructionSimplify.h"
  60. #include "llvm/Analysis/AssumptionCache.h"
  61. #include "llvm/Transforms/Scalar.h"
  62. #include "llvm/Transforms/Utils/Cloning.h"
  63. #include "llvm/Transforms/Utils/Local.h"
  64. #include "llvm/Transforms/Utils/UnrollLoop.h"
  65. #include "llvm/Transforms/Utils/SSAUpdater.h"
  66. #include "llvm/Transforms/Utils/LoopUtils.h"
  67. #include "llvm/Transforms/Utils/PromoteMemToReg.h"
  68. #include "llvm/IR/Instructions.h"
  69. #include "llvm/IR/Module.h"
  70. #include "llvm/IR/Verifier.h"
  71. #include "llvm/IR/PredIteratorCache.h"
  72. #include "llvm/Support/raw_ostream.h"
  73. #include "llvm/Support/Debug.h"
  74. #include "llvm/ADT/SetVector.h"
  75. #include "dxc/DXIL/DxilUtil.h"
  76. #include "dxc/HLSL/HLModule.h"
  77. using namespace llvm;
  78. using namespace hlsl;
  79. // Copied over from LoopUnroll.cpp - RemapInstruction()
  80. static inline void RemapInstruction(Instruction *I,
  81. ValueToValueMapTy &VMap) {
  82. for (unsigned op = 0, E = I->getNumOperands(); op != E; ++op) {
  83. Value *Op = I->getOperand(op);
  84. ValueToValueMapTy::iterator It = VMap.find(Op);
  85. if (It != VMap.end())
  86. I->setOperand(op, It->second);
  87. }
  88. if (PHINode *PN = dyn_cast<PHINode>(I)) {
  89. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
  90. ValueToValueMapTy::iterator It = VMap.find(PN->getIncomingBlock(i));
  91. if (It != VMap.end())
  92. PN->setIncomingBlock(i, cast<BasicBlock>(It->second));
  93. }
  94. }
  95. }
  96. namespace {
  97. class DxilLoopUnroll : public LoopPass {
  98. public:
  99. static char ID;
  100. std::unordered_set<Function *> CleanedUpAlloca;
  101. unsigned MaxIterationAttempt = 0;
  102. DxilLoopUnroll(unsigned MaxIterationAttempt = 128) :
  103. LoopPass(ID),
  104. MaxIterationAttempt(MaxIterationAttempt)
  105. {
  106. initializeDxilLoopUnrollPass(*PassRegistry::getPassRegistry());
  107. }
  108. const char *getPassName() const override { return "Dxil Loop Unroll"; }
  109. bool runOnLoop(Loop *L, LPPassManager &LPM) override;
  110. void getAnalysisUsage(AnalysisUsage &AU) const override {
  111. AU.addRequired<LoopInfoWrapperPass>();
  112. AU.addRequiredID(LoopSimplifyID);
  113. AU.addRequired<AssumptionCacheTracker>();
  114. AU.addRequired<DominatorTreeWrapperPass>();
  115. AU.addPreserved<DominatorTreeWrapperPass>();
  116. }
  117. };
  118. char DxilLoopUnroll::ID;
  119. static void FailLoopUnroll(bool WarnOnly, LLVMContext &Ctx, DebugLoc DL, const char *Message) {
  120. if (WarnOnly) {
  121. if (DL)
  122. Ctx.emitWarning(hlsl::dxilutil::FormatMessageAtLocation(DL, Message));
  123. else
  124. Ctx.emitWarning(hlsl::dxilutil::FormatMessageWithoutLocation(Message));
  125. }
  126. else {
  127. if (DL)
  128. Ctx.emitError(hlsl::dxilutil::FormatMessageAtLocation(DL, Message));
  129. else
  130. Ctx.emitError(hlsl::dxilutil::FormatMessageWithoutLocation(Message));
  131. }
  132. }
  133. struct LoopIteration {
  134. SmallVector<BasicBlock *, 16> Body;
  135. BasicBlock *Latch = nullptr;
  136. BasicBlock *Header = nullptr;
  137. ValueToValueMapTy VarMap;
  138. SetVector<BasicBlock *> Extended; // Blocks that are included in the clone that are not in the core loop body.
  139. LoopIteration() {}
  140. };
  141. static bool GetConstantI1(Value *V, bool *Val=nullptr) {
  142. if (ConstantInt *C = dyn_cast<ConstantInt>(V)) {
  143. if (V->getType()->isIntegerTy(1)) {
  144. if (Val)
  145. *Val = (bool)C->getLimitedValue();
  146. return true;
  147. }
  148. }
  149. return false;
  150. }
  151. // Copied from llvm::SimplifyInstructionsInBlock
  152. static bool SimplifyInstructionsInBlock_NoDelete(BasicBlock *BB,
  153. const TargetLibraryInfo *TLI) {
  154. bool MadeChange = false;
  155. #ifndef NDEBUG
  156. // In debug builds, ensure that the terminator of the block is never replaced
  157. // or deleted by these simplifications. The idea of simplification is that it
  158. // cannot introduce new instructions, and there is no way to replace the
  159. // terminator of a block without introducing a new instruction.
  160. AssertingVH<Instruction> TerminatorVH(--BB->end());
  161. #endif
  162. for (BasicBlock::iterator BI = BB->begin(), E = --BB->end(); BI != E; ) {
  163. assert(!BI->isTerminator());
  164. Instruction *Inst = BI++;
  165. WeakVH BIHandle(BI);
  166. if (recursivelySimplifyInstruction(Inst, TLI)) {
  167. MadeChange = true;
  168. if (BIHandle != BI)
  169. BI = BB->begin();
  170. continue;
  171. }
  172. #if 0 // HLSL Change
  173. MadeChange |= RecursivelyDeleteTriviallyDeadInstructions(Inst, TLI);
  174. #endif // HLSL Change
  175. if (BIHandle != BI)
  176. BI = BB->begin();
  177. }
  178. return MadeChange;
  179. }
  180. static bool IsMarkedFullUnroll(Loop *L) {
  181. if (MDNode *LoopID = L->getLoopID())
  182. return GetUnrollMetadata(LoopID, "llvm.loop.unroll.full");
  183. return false;
  184. }
  185. static bool IsMarkedUnrollCount(Loop *L, unsigned *OutCount) {
  186. if (MDNode *LoopID = L->getLoopID()) {
  187. if (MDNode *MD = GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
  188. assert(MD->getNumOperands() == 2 &&
  189. "Unroll count hint metadata should have two operands.");
  190. unsigned Count =
  191. mdconst::extract<ConstantInt>(MD->getOperand(1))->getZExtValue();
  192. assert(Count >= 1 && "Unroll count must be positive.");
  193. *OutCount = Count;
  194. return true;
  195. }
  196. }
  197. return false;
  198. }
  199. static bool HasSuccessorsInLoop(BasicBlock *BB, Loop *L) {
  200. for (BasicBlock *Succ : successors(BB)) {
  201. if (L->contains(Succ)) {
  202. return true;
  203. }
  204. }
  205. return false;
  206. }
  207. static void DetachFromSuccessors(BasicBlock *BB) {
  208. SmallVector<BasicBlock *, 16> Successors(succ_begin(BB), succ_end(BB));
  209. for (BasicBlock *Succ : Successors) {
  210. Succ->removePredecessor(BB);
  211. }
  212. }
  213. /// Return true if the specified block is in the list.
  214. static bool isExitBlock(BasicBlock *BB,
  215. const SmallVectorImpl<BasicBlock *> &ExitBlocks) {
  216. for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i)
  217. if (ExitBlocks[i] == BB)
  218. return true;
  219. return false;
  220. }
  221. // Copied and modified from LCSSA.cpp
  222. static bool processInstruction(SetVector<BasicBlock *> &Body, Loop &L, Instruction &Inst, DominatorTree &DT, // HLSL Change
  223. const SmallVectorImpl<BasicBlock *> &ExitBlocks,
  224. PredIteratorCache &PredCache, LoopInfo *LI) {
  225. SmallVector<Use *, 16> UsesToRewrite;
  226. BasicBlock *InstBB = Inst.getParent();
  227. for (Use &U : Inst.uses()) {
  228. Instruction *User = cast<Instruction>(U.getUser());
  229. BasicBlock *UserBB = User->getParent();
  230. if (PHINode *PN = dyn_cast<PHINode>(User))
  231. UserBB = PN->getIncomingBlock(U);
  232. if (InstBB != UserBB && /*!L.contains(UserBB)*/!Body.count(UserBB)) // HLSL Change
  233. UsesToRewrite.push_back(&U);
  234. }
  235. // If there are no uses outside the loop, exit with no change.
  236. if (UsesToRewrite.empty())
  237. return false;
  238. #if 0 // HLSL Change
  239. ++NumLCSSA; // We are applying the transformation
  240. #endif // HLSL Change
  241. // Invoke instructions are special in that their result value is not available
  242. // along their unwind edge. The code below tests to see whether DomBB
  243. // dominates
  244. // the value, so adjust DomBB to the normal destination block, which is
  245. // effectively where the value is first usable.
  246. BasicBlock *DomBB = Inst.getParent();
  247. if (InvokeInst *Inv = dyn_cast<InvokeInst>(&Inst))
  248. DomBB = Inv->getNormalDest();
  249. DomTreeNode *DomNode = DT.getNode(DomBB);
  250. SmallVector<PHINode *, 16> AddedPHIs;
  251. SmallVector<PHINode *, 8> PostProcessPHIs;
  252. SSAUpdater SSAUpdate;
  253. SSAUpdate.Initialize(Inst.getType(), Inst.getName());
  254. // Insert the LCSSA phi's into all of the exit blocks dominated by the
  255. // value, and add them to the Phi's map.
  256. for (SmallVectorImpl<BasicBlock *>::const_iterator BBI = ExitBlocks.begin(),
  257. BBE = ExitBlocks.end();
  258. BBI != BBE; ++BBI) {
  259. BasicBlock *ExitBB = *BBI;
  260. if (!DT.dominates(DomNode, DT.getNode(ExitBB)))
  261. continue;
  262. // If we already inserted something for this BB, don't reprocess it.
  263. if (SSAUpdate.HasValueForBlock(ExitBB))
  264. continue;
  265. PHINode *PN = PHINode::Create(Inst.getType(), PredCache.size(ExitBB),
  266. Inst.getName() + ".lcssa", ExitBB->begin());
  267. // Add inputs from inside the loop for this PHI.
  268. for (BasicBlock *Pred : PredCache.get(ExitBB)) {
  269. PN->addIncoming(&Inst, Pred);
  270. // If the exit block has a predecessor not within the loop, arrange for
  271. // the incoming value use corresponding to that predecessor to be
  272. // rewritten in terms of a different LCSSA PHI.
  273. if (/*!L.contains(Pred)*/ !Body.count(Pred)) // HLSL Change
  274. UsesToRewrite.push_back(
  275. &PN->getOperandUse(PN->getOperandNumForIncomingValue(
  276. PN->getNumIncomingValues() - 1)));
  277. }
  278. AddedPHIs.push_back(PN);
  279. // Remember that this phi makes the value alive in this block.
  280. SSAUpdate.AddAvailableValue(ExitBB, PN);
  281. // LoopSimplify might fail to simplify some loops (e.g. when indirect
  282. // branches are involved). In such situations, it might happen that an exit
  283. // for Loop L1 is the header of a disjoint Loop L2. Thus, when we create
  284. // PHIs in such an exit block, we are also inserting PHIs into L2's header.
  285. // This could break LCSSA form for L2 because these inserted PHIs can also
  286. // have uses outside of L2. Remember all PHIs in such situation as to
  287. // revisit than later on. FIXME: Remove this if indirectbr support into
  288. // LoopSimplify gets improved.
  289. if (auto *OtherLoop = LI->getLoopFor(ExitBB))
  290. if (!L.contains(OtherLoop))
  291. PostProcessPHIs.push_back(PN);
  292. }
  293. // Rewrite all uses outside the loop in terms of the new PHIs we just
  294. // inserted.
  295. for (unsigned i = 0, e = UsesToRewrite.size(); i != e; ++i) {
  296. // If this use is in an exit block, rewrite to use the newly inserted PHI.
  297. // This is required for correctness because SSAUpdate doesn't handle uses in
  298. // the same block. It assumes the PHI we inserted is at the end of the
  299. // block.
  300. Instruction *User = cast<Instruction>(UsesToRewrite[i]->getUser());
  301. BasicBlock *UserBB = User->getParent();
  302. if (PHINode *PN = dyn_cast<PHINode>(User))
  303. UserBB = PN->getIncomingBlock(*UsesToRewrite[i]);
  304. if (isa<PHINode>(UserBB->begin()) && isExitBlock(UserBB, ExitBlocks)) {
  305. // Tell the VHs that the uses changed. This updates SCEV's caches.
  306. if (UsesToRewrite[i]->get()->hasValueHandle())
  307. ValueHandleBase::ValueIsRAUWd(*UsesToRewrite[i], UserBB->begin());
  308. UsesToRewrite[i]->set(UserBB->begin());
  309. continue;
  310. }
  311. // Otherwise, do full PHI insertion.
  312. SSAUpdate.RewriteUse(*UsesToRewrite[i]);
  313. }
  314. // Post process PHI instructions that were inserted into another disjoint loop
  315. // and update their exits properly.
  316. for (auto *I : PostProcessPHIs) {
  317. if (I->use_empty())
  318. continue;
  319. BasicBlock *PHIBB = I->getParent();
  320. Loop *OtherLoop = LI->getLoopFor(PHIBB);
  321. SmallVector<BasicBlock *, 8> EBs;
  322. OtherLoop->getExitBlocks(EBs);
  323. if (EBs.empty())
  324. continue;
  325. // Recurse and re-process each PHI instruction. FIXME: we should really
  326. // convert this entire thing to a worklist approach where we process a
  327. // vector of instructions...
  328. processInstruction(Body, *OtherLoop, *I, DT, EBs, PredCache, LI);
  329. }
  330. // Remove PHI nodes that did not have any uses rewritten.
  331. for (unsigned i = 0, e = AddedPHIs.size(); i != e; ++i) {
  332. if (AddedPHIs[i]->use_empty())
  333. AddedPHIs[i]->eraseFromParent();
  334. }
  335. return true;
  336. }
  337. // Copied from LCSSA.cpp
  338. static bool blockDominatesAnExit(BasicBlock *BB,
  339. DominatorTree &DT,
  340. const SmallVectorImpl<BasicBlock *> &ExitBlocks) {
  341. DomTreeNode *DomNode = DT.getNode(BB);
  342. for (BasicBlock *Exit : ExitBlocks)
  343. if (DT.dominates(DomNode, DT.getNode(Exit)))
  344. return true;
  345. return false;
  346. };
  347. // Copied from LCSSA.cpp
  348. //
  349. // We need to recreate the LCSSA form since our loop boundary is potentially different from
  350. // the canonical one.
  351. static bool CreateLCSSA(SetVector<BasicBlock *> &Body, const SmallVectorImpl<BasicBlock *> &ExitBlocks, Loop *L, DominatorTree &DT, LoopInfo *LI) {
  352. PredIteratorCache PredCache;
  353. bool Changed = false;
  354. // Look at all the instructions in the loop, checking to see if they have uses
  355. // outside the loop. If so, rewrite those uses.
  356. for (SetVector<BasicBlock *>::iterator BBI = Body.begin(), BBE = Body.end();
  357. BBI != BBE; ++BBI) {
  358. BasicBlock *BB = *BBI;
  359. // For large loops, avoid use-scanning by using dominance information: In
  360. // particular, if a block does not dominate any of the loop exits, then none
  361. // of the values defined in the block could be used outside the loop.
  362. if (!blockDominatesAnExit(BB, DT, ExitBlocks))
  363. continue;
  364. for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) {
  365. // Reject two common cases fast: instructions with no uses (like stores)
  366. // and instructions with one use that is in the same block as this.
  367. if (I->use_empty() ||
  368. (I->hasOneUse() && I->user_back()->getParent() == BB &&
  369. !isa<PHINode>(I->user_back())))
  370. continue;
  371. Changed |= processInstruction(Body, *L, *I, DT, ExitBlocks, PredCache, LI);
  372. }
  373. }
  374. return Changed;
  375. }
  376. static Value *GetGEPPtrOrigin(GEPOperator *GEP) {
  377. Value *Ptr = GEP->getPointerOperand();
  378. while (Ptr) {
  379. if (AllocaInst *AI = dyn_cast<AllocaInst>(Ptr)) {
  380. return AI;
  381. }
  382. else if (GEPOperator *NewGEP = dyn_cast<GEPOperator>(Ptr)) {
  383. Ptr = NewGEP->getPointerOperand();
  384. }
  385. else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) {
  386. return GV;
  387. }
  388. else {
  389. break;
  390. }
  391. }
  392. return nullptr;
  393. }
  394. // Find all blocks in the loop with instructions that
  395. // would require an unroll to be correct.
  396. //
  397. // For example:
  398. // for (int i = 0; i < 10; i++) {
  399. // gep i
  400. // }
  401. //
  402. static void FindProblemBlocks(BasicBlock *Header, const SmallVectorImpl<BasicBlock *> &BlocksInLoop, std::unordered_set<BasicBlock *> &ProblemBlocks, SetVector<AllocaInst *> &ProblemAllocas) {
  403. SmallVector<Instruction *, 16> WorkList;
  404. std::unordered_set<BasicBlock *> BlocksInLoopSet(BlocksInLoop.begin(), BlocksInLoop.end());
  405. std::unordered_set<Instruction *> InstructionsSeen;
  406. for (Instruction &I : *Header) {
  407. PHINode *PN = dyn_cast<PHINode>(&I);
  408. if (!PN)
  409. break;
  410. WorkList.push_back(PN);
  411. InstructionsSeen.insert(PN);
  412. }
  413. while (WorkList.size()) {
  414. Instruction *I = WorkList.pop_back_val();
  415. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) {
  416. Type *EltType = GEP->getType()->getPointerElementType();
  417. // NOTE: This is a very convservative in the following conditions:
  418. // - constant global resource arrays with external linkage (these can be
  419. // dynamically accessed)
  420. // - global resource arrays or alloca resource arrays, as long as all
  421. // writes come from the same original resource definition (which can
  422. // also be an array).
  423. //
  424. // We may want to make this more precise in the future if it becomes a
  425. // problem.
  426. //
  427. if (hlsl::dxilutil::IsHLSLObjectType(EltType)) {
  428. if (Value *Ptr = GetGEPPtrOrigin(cast<GEPOperator>(GEP))) {
  429. if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) {
  430. if (!GV->isExternalLinkage(llvm::GlobalValue::ExternalLinkage))
  431. ProblemBlocks.insert(GEP->getParent());
  432. }
  433. else if (AllocaInst *AI = dyn_cast<AllocaInst>(Ptr)) {
  434. ProblemAllocas.insert(AI);
  435. ProblemBlocks.insert(GEP->getParent());
  436. }
  437. }
  438. continue; // Stop Propagating
  439. }
  440. }
  441. for (User *U : I->users()) {
  442. if (Instruction *UserI = dyn_cast<Instruction>(U)) {
  443. if (!InstructionsSeen.count(UserI) &&
  444. BlocksInLoopSet.count(UserI->getParent()))
  445. {
  446. InstructionsSeen.insert(UserI);
  447. WorkList.push_back(UserI);
  448. }
  449. }
  450. }
  451. }
  452. }
  453. // Helper function for getting GEP's const index value
  454. inline static int64_t GetGEPIndex(GEPOperator *GEP, unsigned idx) {
  455. return cast<ConstantInt>(GEP->getOperand(idx + 1))->getSExtValue();
  456. }
  457. // Replace allocas with all constant indices with scalar allocas, then promote
  458. // them to values where possible (mem2reg).
  459. //
  460. // Before loop unroll, we did not have constant indices for arrays and SROA was
  461. // unable to break them into scalars. Now that unroll has potentially given
  462. // them constant values, we need to turn them into scalars.
  463. //
  464. // if "AllowOOBIndex" is true, it turns any out of bound index into 0.
  465. // Otherwise it emits an error and fails compilation.
  466. //
  467. template<typename IteratorT>
  468. static bool BreakUpArrayAllocas(bool AllowOOBIndex, IteratorT ItBegin, IteratorT ItEnd, DominatorTree *DT, AssumptionCache *AC) {
  469. bool Success = true;
  470. SmallVector<AllocaInst *, 8> WorkList(ItBegin, ItEnd);
  471. SmallVector<GEPOperator *, 16> GEPs;
  472. while (WorkList.size()) {
  473. AllocaInst *AI = WorkList.pop_back_val();
  474. Type *AllocaType = AI->getAllocatedType();
  475. // Only deal with array allocas.
  476. if (!AllocaType->isArrayTy())
  477. continue;
  478. unsigned ArraySize = AI->getAllocatedType()->getArrayNumElements();
  479. Type *ElementType = AllocaType->getArrayElementType();
  480. if (!ArraySize)
  481. continue;
  482. GEPs.clear(); // Re-use array
  483. for (User *U : AI->users()) {
  484. if (GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
  485. if (!GEP->hasAllConstantIndices() || GEP->getNumIndices() < 2 ||
  486. GetGEPIndex(GEP, 0) != 0)
  487. {
  488. GEPs.clear();
  489. break;
  490. }
  491. else {
  492. GEPs.push_back(GEP);
  493. }
  494. }
  495. else {
  496. GEPs.clear();
  497. break;
  498. }
  499. }
  500. if (!GEPs.size())
  501. continue;
  502. SmallVector<AllocaInst *, 8> ScalarAllocas;
  503. ScalarAllocas.resize(ArraySize);
  504. IRBuilder<> B(AI);
  505. for (GEPOperator *GEP : GEPs) {
  506. int64_t idx = GetGEPIndex(GEP, 1);
  507. GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(GEP);
  508. if (idx < 0 || idx >= ArraySize) {
  509. if (AllowOOBIndex)
  510. idx = 0;
  511. else {
  512. Success = false;
  513. if (GEPInst)
  514. hlsl::dxilutil::EmitErrorOnInstruction(GEPInst, "Array access out of bound.");
  515. continue;
  516. }
  517. }
  518. AllocaInst *ScalarAlloca = ScalarAllocas[idx];
  519. if (!ScalarAlloca) {
  520. ScalarAlloca = B.CreateAlloca(ElementType);
  521. ScalarAllocas[idx] = ScalarAlloca;
  522. if (ElementType->isArrayTy()) {
  523. WorkList.push_back(ScalarAlloca);
  524. }
  525. }
  526. Value *NewPointer = nullptr;
  527. if (ElementType->isArrayTy()) {
  528. SmallVector<Value *, 2> Indices;
  529. Indices.push_back(B.getInt32(0));
  530. for (unsigned i = 2; i < GEP->getNumIndices(); i++) {
  531. Indices.push_back(GEP->getOperand(i + 1));
  532. }
  533. NewPointer = B.CreateGEP(ScalarAlloca, Indices);
  534. } else {
  535. NewPointer = ScalarAlloca;
  536. }
  537. GEP->replaceAllUsesWith(NewPointer);
  538. }
  539. if (!ElementType->isArrayTy()) {
  540. std::remove(ScalarAllocas.begin(), ScalarAllocas.end(), nullptr);
  541. PromoteMemToReg(ScalarAllocas, *DT, nullptr, AC);
  542. }
  543. }
  544. return Success;
  545. }
  546. static bool ContainsFloatingPointType(Type *Ty) {
  547. if (Ty->isFloatingPointTy()) {
  548. return true;
  549. }
  550. else if (Ty->isArrayTy()) {
  551. return ContainsFloatingPointType(Ty->getArrayElementType());
  552. }
  553. else if (Ty->isVectorTy()) {
  554. return ContainsFloatingPointType(Ty->getVectorElementType());
  555. }
  556. else if (Ty->isStructTy()) {
  557. for (unsigned i = 0, NumStructElms = Ty->getStructNumElements(); i < NumStructElms; i++) {
  558. if (ContainsFloatingPointType(Ty->getStructElementType(i)))
  559. return true;
  560. }
  561. }
  562. return false;
  563. }
  564. static bool Mem2Reg(Function &F, DominatorTree &DT, AssumptionCache &AC) {
  565. BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function
  566. bool Changed = false;
  567. std::vector<AllocaInst*> Allocas;
  568. while (1) {
  569. Allocas.clear();
  570. // Find allocas that are safe to promote, by looking at all instructions in
  571. // the entry node
  572. for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I)
  573. if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) // Is it an alloca?
  574. if (isAllocaPromotable(AI) &&
  575. (!HLModule::HasPreciseAttributeWithMetadata(AI) || !ContainsFloatingPointType(AI->getAllocatedType())))
  576. Allocas.push_back(AI);
  577. if (Allocas.empty()) break;
  578. PromoteMemToReg(Allocas, DT, nullptr, &AC);
  579. Changed = true;
  580. }
  581. return Changed;
  582. }
  583. bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
  584. bool HasExplicitLoopCount = false;
  585. unsigned UnrollCount = 0;
  586. // If the loop is not marked as [unroll], don't do anything.
  587. if (IsMarkedUnrollCount(L, &UnrollCount)) {
  588. HasExplicitLoopCount = true;
  589. }
  590. else if (!IsMarkedFullUnroll(L)) {
  591. return false;
  592. }
  593. if (!L->isSafeToClone())
  594. return false;
  595. DebugLoc LoopLoc = L->getStartLoc(); // Debug location for the start of the loop.
  596. Function *F = L->getHeader()->getParent();
  597. bool FxcCompatMode = false;
  598. if (F->getParent()->HasHLModule()) {
  599. HLModule &HM = F->getParent()->GetHLModule();
  600. FxcCompatMode = HM.GetHLOptions().bFXCCompatMode;
  601. }
  602. // Analysis passes
  603. DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  604. AssumptionCache *AC =
  605. &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(*F);
  606. LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  607. Loop *OuterL = L->getParentLoop();
  608. BasicBlock *Latch = L->getLoopLatch();
  609. BasicBlock *Header = L->getHeader();
  610. BasicBlock *Predecessor = L->getLoopPredecessor();
  611. const DataLayout &DL = F->getParent()->getDataLayout();
  612. // Quit if we don't have a single latch block or predecessor
  613. if (!Latch || !Predecessor) {
  614. return false;
  615. }
  616. // If the loop exit condition is not in the latch, then the loop is not rotated. Give up.
  617. if (!cast<BranchInst>(Latch->getTerminator())->isConditional()) {
  618. return false;
  619. }
  620. // Promote alloca's
  621. if (!CleanedUpAlloca.count(F)) {
  622. CleanedUpAlloca.insert(F);
  623. Mem2Reg(*F, *DT, *AC);
  624. }
  625. SmallVector<BasicBlock *, 16> ExitBlocks;
  626. L->getExitBlocks(ExitBlocks);
  627. std::unordered_set<BasicBlock *> ExitBlockSet(ExitBlocks.begin(), ExitBlocks.end());
  628. SmallVector<BasicBlock *, 16> BlocksInLoop; // Set of blocks including both body and exits
  629. BlocksInLoop.append(L->getBlocks().begin(), L->getBlocks().end());
  630. BlocksInLoop.append(ExitBlocks.begin(), ExitBlocks.end());
  631. // Heuristically find blocks that likely need to be unrolled
  632. SetVector<AllocaInst *> ProblemAllocas;
  633. std::unordered_set<BasicBlock *> ProblemBlocks;
  634. FindProblemBlocks(L->getHeader(), BlocksInLoop, ProblemBlocks, ProblemAllocas);
  635. // Keep track of the PHI nodes at the header.
  636. SmallVector<PHINode *, 16> PHIs;
  637. for (auto it = Header->begin(); it != Header->end(); it++) {
  638. if (PHINode *PN = dyn_cast<PHINode>(it)) {
  639. PHIs.push_back(PN);
  640. }
  641. else {
  642. break;
  643. }
  644. }
  645. // Quick simplification of PHINode incoming values
  646. for (PHINode *PN : PHIs) {
  647. for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
  648. Value *OldIncomingV = PN->getIncomingValue(i);
  649. if (Instruction *IncomingI = dyn_cast<Instruction>(OldIncomingV)) {
  650. if (Value *NewIncomingV = llvm::SimplifyInstruction(IncomingI, DL)) {
  651. PN->setIncomingValue(i, NewIncomingV);
  652. }
  653. }
  654. }
  655. }
  656. SetVector<BasicBlock *> ToBeCloned; // List of blocks that will be cloned.
  657. for (BasicBlock *BB : L->getBlocks()) // Include the body right away
  658. ToBeCloned.insert(BB);
  659. // Find the exit blocks that also need to be included
  660. // in the unroll.
  661. SmallVector<BasicBlock *, 8> NewExits; // New set of exit blocks as boundaries for LCSSA
  662. SmallVector<BasicBlock *, 8> FakeExits; // Set of blocks created to allow cloning original exit blocks.
  663. for (BasicBlock *BB : ExitBlocks) {
  664. bool CloneThisExitBlock = ProblemBlocks.count(BB);
  665. if (CloneThisExitBlock) {
  666. ToBeCloned.insert(BB);
  667. // If we are cloning this basic block, we must create a new exit
  668. // block for inserting LCSSA PHI nodes.
  669. BasicBlock *FakeExit = BasicBlock::Create(BB->getContext(), "loop.exit.new");
  670. F->getBasicBlockList().insert(BB, FakeExit);
  671. TerminatorInst *OldTerm = BB->getTerminator();
  672. OldTerm->removeFromParent();
  673. FakeExit->getInstList().push_back(OldTerm);
  674. BranchInst::Create(FakeExit, BB);
  675. for (BasicBlock *Succ : successors(FakeExit)) {
  676. for (Instruction &I : *Succ) {
  677. if (PHINode *PN = dyn_cast<PHINode>(&I)) {
  678. for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
  679. if (PN->getIncomingBlock(i) == BB)
  680. PN->setIncomingBlock(i, FakeExit);
  681. }
  682. }
  683. }
  684. }
  685. NewExits.push_back(FakeExit);
  686. FakeExits.push_back(FakeExit);
  687. // Update Dom tree with new exit
  688. if (!DT->getNode(FakeExit))
  689. DT->addNewBlock(FakeExit, BB);
  690. }
  691. else {
  692. // If we are not including this exit block in the unroll,
  693. // use it for LCSSA as normal.
  694. NewExits.push_back(BB);
  695. }
  696. }
  697. // Simplify the PHI nodes that have single incoming value. The original LCSSA form
  698. // (if exists) does not necessarily work for our unroll because we may be unrolling
  699. // from a different boundary.
  700. for (BasicBlock *BB : BlocksInLoop)
  701. hlsl::dxilutil::SimplifyTrivialPHIs(BB);
  702. // Re-establish LCSSA form to get ready for unrolling.
  703. CreateLCSSA(ToBeCloned, NewExits, L, *DT, LI);
  704. SmallVector<std::unique_ptr<LoopIteration>, 16> Iterations; // List of cloned iterations
  705. bool Succeeded = false;
  706. for (unsigned IterationI = 0; IterationI < this->MaxIterationAttempt; IterationI++) {
  707. LoopIteration *PrevIteration = nullptr;
  708. if (Iterations.size())
  709. PrevIteration = Iterations.back().get();
  710. Iterations.push_back(llvm::make_unique<LoopIteration>());
  711. LoopIteration &CurIteration = *Iterations.back().get();
  712. // Clone the blocks.
  713. for (BasicBlock *BB : ToBeCloned) {
  714. BasicBlock *ClonedBB = CloneBasicBlock(BB, CurIteration.VarMap);
  715. CurIteration.VarMap[BB] = ClonedBB;
  716. ClonedBB->insertInto(F, Header);
  717. if (ExitBlockSet.count(BB))
  718. CurIteration.Extended.insert(ClonedBB);
  719. CurIteration.Body.push_back(ClonedBB);
  720. // Identify the special blocks.
  721. if (BB == Latch) {
  722. CurIteration.Latch = ClonedBB;
  723. }
  724. if (BB == Header) {
  725. CurIteration.Header = ClonedBB;
  726. }
  727. }
  728. for (BasicBlock *BB : ToBeCloned) {
  729. BasicBlock *ClonedBB = cast<BasicBlock>(CurIteration.VarMap[BB]);
  730. // If branching to outside of the loop, need to update the
  731. // phi nodes there to include new values.
  732. for (BasicBlock *Succ : successors(ClonedBB)) {
  733. if (ToBeCloned.count(Succ))
  734. continue;
  735. for (Instruction &I : *Succ) {
  736. PHINode *PN = dyn_cast<PHINode>(&I);
  737. if (!PN)
  738. break;
  739. // Find the incoming value for this new block. If there is an entry
  740. // for this block in the map, then it was defined in the loop, use it.
  741. // Otherwise it came from outside the loop.
  742. Value *OldIncoming = PN->getIncomingValueForBlock(BB);
  743. Value *NewIncoming = OldIncoming;
  744. ValueToValueMapTy::iterator Itor = CurIteration.VarMap.find(OldIncoming);
  745. if (Itor != CurIteration.VarMap.end())
  746. NewIncoming = Itor->second;
  747. PN->addIncoming(NewIncoming, ClonedBB);
  748. }
  749. }
  750. }
  751. // Remap the instructions inside of cloned blocks.
  752. for (BasicBlock *BB : CurIteration.Body) {
  753. for (Instruction &I : *BB) {
  754. ::RemapInstruction(&I, CurIteration.VarMap);
  755. }
  756. }
  757. // If this is the first block
  758. if (!PrevIteration) {
  759. // Replace the phi nodes in the clone block with the values coming
  760. // from outside of the loop
  761. for (PHINode *PN : PHIs) {
  762. PHINode *ClonedPN = cast<PHINode>(CurIteration.VarMap[PN]);
  763. Value *ReplacementVal = ClonedPN->getIncomingValueForBlock(Predecessor);
  764. ClonedPN->replaceAllUsesWith(ReplacementVal);
  765. ClonedPN->eraseFromParent();
  766. CurIteration.VarMap[PN] = ReplacementVal;
  767. }
  768. }
  769. else {
  770. // Replace the phi nodes with the value defined INSIDE the previous iteration.
  771. for (PHINode *PN : PHIs) {
  772. PHINode *ClonedPN = cast<PHINode>(CurIteration.VarMap[PN]);
  773. Value *ReplacementVal = PrevIteration->VarMap[PN->getIncomingValueForBlock(Latch)];
  774. ClonedPN->replaceAllUsesWith(ReplacementVal);
  775. ClonedPN->eraseFromParent();
  776. CurIteration.VarMap[PN] = ReplacementVal;
  777. }
  778. // Make the latch of the previous iteration branch to the header
  779. // of this new iteration.
  780. if (BranchInst *BI = dyn_cast<BranchInst>(PrevIteration->Latch->getTerminator())) {
  781. for (unsigned i = 0; i < BI->getNumSuccessors(); i++) {
  782. if (BI->getSuccessor(i) == PrevIteration->Header) {
  783. BI->setSuccessor(i, CurIteration.Header);
  784. break;
  785. }
  786. }
  787. }
  788. }
  789. // Simplify instructions in the cloned blocks to create
  790. // constant exit conditions.
  791. for (BasicBlock *ClonedBB : CurIteration.Body)
  792. SimplifyInstructionsInBlock_NoDelete(ClonedBB, NULL);
  793. // Check exit condition to see if we fully unrolled the loop
  794. if (BranchInst *BI = dyn_cast<BranchInst>(CurIteration.Latch->getTerminator())) {
  795. bool Cond = false;
  796. if (GetConstantI1(BI->getCondition(), &Cond)) {
  797. if (BI->getSuccessor(Cond ? 1 : 0) == CurIteration.Header) {
  798. Succeeded = true;
  799. break;
  800. }
  801. }
  802. }
  803. // We've reached the N defined in [unroll(N)]
  804. if (HasExplicitLoopCount && IterationI+1 >= UnrollCount) {
  805. Succeeded = true;
  806. BranchInst *BI = cast<BranchInst>(CurIteration.Latch->getTerminator());
  807. BasicBlock *ExitBlock = nullptr;
  808. for (unsigned i = 0; i < BI->getNumSuccessors(); i++) {
  809. BasicBlock *Succ = BI->getSuccessor(i);
  810. if (Succ != CurIteration.Header) {
  811. ExitBlock = Succ;
  812. break;
  813. }
  814. }
  815. BranchInst *NewBI = BranchInst::Create(ExitBlock, BI);
  816. BI->replaceAllUsesWith(NewBI);
  817. BI->eraseFromParent();
  818. break;
  819. }
  820. }
  821. if (Succeeded) {
  822. // We are going to be cleaning them up later. Make sure
  823. // they're in the entry block so deleting loop blocks doesn't
  824. // kill them too.
  825. for (AllocaInst *AI : ProblemAllocas)
  826. DXASSERT_LOCALVAR(AI, AI->getParent() == &F->getEntryBlock(), "Alloca is not in entry block.");
  827. LoopIteration &FirstIteration = *Iterations.front().get();
  828. // Make the predecessor branch to the first new header.
  829. {
  830. BranchInst *BI = cast<BranchInst>(Predecessor->getTerminator());
  831. for (unsigned i = 0, NumSucc = BI->getNumSuccessors(); i < NumSucc; i++) {
  832. if (BI->getSuccessor(i) == Header) {
  833. BI->setSuccessor(i, FirstIteration.Header);
  834. }
  835. }
  836. }
  837. if (OuterL) {
  838. // Core body blocks need to be added to outer loop
  839. for (size_t i = 0; i < Iterations.size(); i++) {
  840. LoopIteration &Iteration = *Iterations[i].get();
  841. for (BasicBlock *BB : Iteration.Body) {
  842. if (!Iteration.Extended.count(BB)) {
  843. OuterL->addBasicBlockToLoop(BB, *LI);
  844. }
  845. }
  846. }
  847. // Our newly created exit blocks may need to be added to outer loop
  848. for (BasicBlock *BB : FakeExits) {
  849. if (HasSuccessorsInLoop(BB, OuterL))
  850. OuterL->addBasicBlockToLoop(BB, *LI);
  851. }
  852. // Cloned exit blocks may need to be added to outer loop
  853. for (size_t i = 0; i < Iterations.size(); i++) {
  854. LoopIteration &Iteration = *Iterations[i].get();
  855. for (BasicBlock *BB : Iteration.Extended) {
  856. if (HasSuccessorsInLoop(BB, OuterL))
  857. OuterL->addBasicBlockToLoop(BB, *LI);
  858. }
  859. }
  860. }
  861. // Remove the original blocks that we've cloned from all loops.
  862. for (BasicBlock *BB : ToBeCloned)
  863. LI->removeBlock(BB);
  864. LPM.deleteLoopFromQueue(L);
  865. // Remove dead blocks.
  866. for (BasicBlock *BB : ToBeCloned)
  867. DetachFromSuccessors(BB);
  868. for (BasicBlock *BB : ToBeCloned)
  869. BB->dropAllReferences();
  870. for (BasicBlock *BB : ToBeCloned)
  871. BB->eraseFromParent();
  872. // Blocks need to be removed from DomTree. There's no easy way
  873. // to remove them in the right order, so just make DomTree
  874. // recalculate.
  875. DT->recalculate(*F);
  876. if (OuterL) {
  877. // This process may have created multiple back edges for the
  878. // parent loop. Simplify to keep it well-formed.
  879. simplifyLoop(OuterL, DT, LI, this, nullptr, nullptr, AC);
  880. }
  881. // Now that we potentially turned some GEP indices into constants,
  882. // try to clean up their allocas.
  883. if (!BreakUpArrayAllocas(FxcCompatMode /* allow oob index */, ProblemAllocas.begin(), ProblemAllocas.end(), DT, AC)) {
  884. FailLoopUnroll(false, F->getContext(), LoopLoc, "Could not unroll loop due to out of bound array access.");
  885. }
  886. return true;
  887. }
  888. // If we were unsuccessful in unrolling the loop
  889. else {
  890. FailLoopUnroll(FxcCompatMode /*warn only*/, F->getContext(), LoopLoc, "Could not unroll loop.");
  891. // Remove all the cloned blocks
  892. for (std::unique_ptr<LoopIteration> &Ptr : Iterations) {
  893. LoopIteration &Iteration = *Ptr.get();
  894. for (BasicBlock *BB : Iteration.Body)
  895. DetachFromSuccessors(BB);
  896. }
  897. for (std::unique_ptr<LoopIteration> &Ptr : Iterations) {
  898. LoopIteration &Iteration = *Ptr.get();
  899. for (BasicBlock *BB : Iteration.Body)
  900. BB->dropAllReferences();
  901. }
  902. for (std::unique_ptr<LoopIteration> &Ptr : Iterations) {
  903. LoopIteration &Iteration = *Ptr.get();
  904. for (BasicBlock *BB : Iteration.Body)
  905. BB->eraseFromParent();
  906. }
  907. return false;
  908. }
  909. }
  910. }
  911. Pass *llvm::createDxilLoopUnrollPass(unsigned MaxIterationAttempt) {
  912. return new DxilLoopUnroll(MaxIterationAttempt);
  913. }
  914. INITIALIZE_PASS(DxilLoopUnroll, "dxil-loop-unroll", "Dxil Unroll loops", false, false)