DxilLoopUnroll.cpp 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070
  1. //===- DxilLoopUnroll.cpp - Special Unroll for Constant Values ------------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //===----------------------------------------------------------------------===//
  10. //
  11. // Special loop unroll routine for creating mandatory constant values and
  12. // loops that have exits.
  13. //
  14. // Overview of algorithm:
  15. //
  16. // 1. Identify a set of blocks to unroll.
  17. //
  18. // LLVM's concept of loop excludes exit blocks, which are blocks that no
  19. // longer have a path to the loop latch. However, some exit blocks in HLSL
  20. // also need to be unrolled. For example:
  21. //
  22. // [unroll]
  23. // for (uint i = 0; i < 4; i++)
  24. // {
  25. // if (...)
  26. // {
  27. // // This block here is an exit block, since it is
  28. // // guaranteed to exit the loop.
  29. // ...
  30. // a[i] = ...; // Indexing requires unroll.
  31. // return;
  32. // }
  33. // }
  34. //
  35. //
  36. // 2. Create LCSSA based on the new loop boundary.
  37. //
  38. // See LCSSA.cpp for more details. It creates trivial PHI nodes for any
  39. // outgoing values of the loop at the exit blocks, so when the loop body
  40. // gets cloned, the outgoing values can be added to those PHI nodes easily.
  41. //
  42. // We are using a modified LCSSA routine here because we are including some
  43. // of the original exit blocks in the unroll.
  44. //
  45. //
  46. // 3. Unroll the loop until we succeed.
  47. //
  48. // Unlike LLVM, we do not try to find a loop count before unrolling.
  49. // Instead, we unroll to find a constant terminal condition. Give up when we
  50. // fail to do so.
  51. //
  52. //
  53. //===----------------------------------------------------------------------===//
  54. #include "llvm/Pass.h"
  55. #include "llvm/Analysis/LoopPass.h"
  56. #include "llvm/Analysis/InstructionSimplify.h"
  57. #include "llvm/Analysis/AssumptionCache.h"
  58. #include "llvm/Analysis/LoopPass.h"
  59. #include "llvm/Analysis/InstructionSimplify.h"
  60. #include "llvm/Analysis/AssumptionCache.h"
  61. #include "llvm/Analysis/ScalarEvolution.h"
  62. #include "llvm/Transforms/Scalar.h"
  63. #include "llvm/Transforms/Utils/Cloning.h"
  64. #include "llvm/Transforms/Utils/Local.h"
  65. #include "llvm/Transforms/Utils/UnrollLoop.h"
  66. #include "llvm/Transforms/Utils/SSAUpdater.h"
  67. #include "llvm/Transforms/Utils/LoopUtils.h"
  68. #include "llvm/Transforms/Utils/PromoteMemToReg.h"
  69. #include "llvm/IR/Instructions.h"
  70. #include "llvm/IR/Module.h"
  71. #include "llvm/IR/Verifier.h"
  72. #include "llvm/IR/PredIteratorCache.h"
  73. #include "llvm/IR/IntrinsicInst.h"
  74. #include "llvm/Support/raw_ostream.h"
  75. #include "llvm/Support/Debug.h"
  76. #include "llvm/ADT/SetVector.h"
  77. #include "llvm/IR/LegacyPassManager.h"
  78. #include "llvm/Analysis/ScalarEvolutionExpressions.h"
  79. #include "dxc/DXIL/DxilUtil.h"
  80. #include "dxc/HLSL/HLModule.h"
  81. #include "llvm/Analysis/DxilValueCache.h"
  82. using namespace llvm;
  83. using namespace hlsl;
  84. // Copied over from LoopUnroll.cpp - RemapInstruction()
  85. static inline void RemapInstruction(Instruction *I,
  86. ValueToValueMapTy &VMap) {
  87. for (unsigned op = 0, E = I->getNumOperands(); op != E; ++op) {
  88. Value *Op = I->getOperand(op);
  89. ValueToValueMapTy::iterator It = VMap.find(Op);
  90. if (It != VMap.end())
  91. I->setOperand(op, It->second);
  92. }
  93. if (PHINode *PN = dyn_cast<PHINode>(I)) {
  94. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
  95. ValueToValueMapTy::iterator It = VMap.find(PN->getIncomingBlock(i));
  96. if (It != VMap.end())
  97. PN->setIncomingBlock(i, cast<BasicBlock>(It->second));
  98. }
  99. }
  100. }
  101. namespace {
  102. class DxilLoopUnroll : public LoopPass {
  103. public:
  104. static char ID;
  105. std::unordered_set<Function *> CleanedUpAlloca;
  106. const unsigned MaxIterationAttempt;
  107. DxilLoopUnroll(unsigned MaxIterationAttempt = 1024) :
  108. LoopPass(ID),
  109. MaxIterationAttempt(MaxIterationAttempt)
  110. {
  111. initializeDxilLoopUnrollPass(*PassRegistry::getPassRegistry());
  112. }
  113. const char *getPassName() const override { return "Dxil Loop Unroll"; }
  114. bool runOnLoop(Loop *L, LPPassManager &LPM) override;
  115. void getAnalysisUsage(AnalysisUsage &AU) const override {
  116. AU.addRequired<LoopInfoWrapperPass>();
  117. AU.addRequired<AssumptionCacheTracker>();
  118. AU.addRequired<DominatorTreeWrapperPass>();
  119. AU.addPreserved<DominatorTreeWrapperPass>();
  120. AU.addRequired<ScalarEvolution>();
  121. AU.addRequired<DxilValueCache>();
  122. AU.addRequiredID(LoopSimplifyID);
  123. }
  124. };
  125. char DxilLoopUnroll::ID;
  126. static void FailLoopUnroll(bool WarnOnly, LLVMContext &Ctx, DebugLoc DL, const Twine &Message) {
  127. if (WarnOnly) {
  128. if (DL)
  129. Ctx.emitWarning(hlsl::dxilutil::FormatMessageAtLocation(DL, Message));
  130. else
  131. Ctx.emitWarning(hlsl::dxilutil::FormatMessageWithoutLocation(Message));
  132. }
  133. else {
  134. if (DL)
  135. Ctx.emitError(hlsl::dxilutil::FormatMessageAtLocation(DL, Message));
  136. else
  137. Ctx.emitError(hlsl::dxilutil::FormatMessageWithoutLocation(Message));
  138. }
  139. }
  140. struct LoopIteration {
  141. SmallVector<BasicBlock *, 16> Body;
  142. BasicBlock *Latch = nullptr;
  143. BasicBlock *Header = nullptr;
  144. ValueToValueMapTy VarMap;
  145. SetVector<BasicBlock *> Extended; // Blocks that are included in the clone that are not in the core loop body.
  146. LoopIteration() {}
  147. };
  148. static bool GetConstantI1(Value *V, bool *Val=nullptr) {
  149. if (ConstantInt *C = dyn_cast<ConstantInt>(V)) {
  150. if (V->getType()->isIntegerTy(1)) {
  151. if (Val)
  152. *Val = (bool)C->getLimitedValue();
  153. return true;
  154. }
  155. }
  156. return false;
  157. }
  158. static bool IsMarkedFullUnroll(Loop *L) {
  159. if (MDNode *LoopID = L->getLoopID())
  160. return GetUnrollMetadata(LoopID, "llvm.loop.unroll.full");
  161. return false;
  162. }
  163. static bool IsMarkedUnrollCount(Loop *L, int *OutCount) {
  164. if (MDNode *LoopID = L->getLoopID()) {
  165. if (MDNode *MD = GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
  166. assert(MD->getNumOperands() == 2 &&
  167. "Unroll count hint metadata should have two operands.");
  168. ConstantInt *Val = mdconst::extract<ConstantInt>(MD->getOperand(1));
  169. int Count = Val->getZExtValue();
  170. *OutCount = Count;
  171. return true;
  172. }
  173. }
  174. return false;
  175. }
  176. static bool HasSuccessorsInLoop(BasicBlock *BB, Loop *L) {
  177. for (BasicBlock *Succ : successors(BB)) {
  178. if (L->contains(Succ)) {
  179. return true;
  180. }
  181. }
  182. return false;
  183. }
  184. static void DetachFromSuccessors(BasicBlock *BB) {
  185. SmallVector<BasicBlock *, 16> Successors(succ_begin(BB), succ_end(BB));
  186. for (BasicBlock *Succ : Successors) {
  187. Succ->removePredecessor(BB);
  188. }
  189. }
  190. /// Return true if the specified block is in the list.
  191. static bool isExitBlock(BasicBlock *BB,
  192. const SmallVectorImpl<BasicBlock *> &ExitBlocks) {
  193. for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i)
  194. if (ExitBlocks[i] == BB)
  195. return true;
  196. return false;
  197. }
// Copied and modified from LCSSA.cpp
//
// Puts a single instruction into loop-closed SSA form with respect to the
// block set `Body` (the HLSL change: the loop boundary is `Body`, not
// `L.contains()`).
//
//  Body       - blocks considered to be "inside" the loop.
//  L          - the loop being processed; only used to decide whether an exit
//               block belongs to a different, disjoint loop.
//  Inst       - the instruction whose outside uses are rewritten.
//  DT         - dominator tree used to pick the exit blocks that need a PHI.
//  ExitBlocks - exit blocks of the region; PHIs are inserted at their start.
//  PredCache  - cache of predecessor lists for the exit blocks.
//  LI         - loop info, used to find loops that contain an exit block.
//
// Returns false if Inst has no use that needs rewriting, true otherwise.
static bool processInstruction(SetVector<BasicBlock *> &Body, Loop &L, Instruction &Inst, DominatorTree &DT, // HLSL Change
                               const SmallVectorImpl<BasicBlock *> &ExitBlocks,
                               PredIteratorCache &PredCache, LoopInfo *LI) {
  SmallVector<Use *, 16> UsesToRewrite;

  BasicBlock *InstBB = Inst.getParent();

  // Collect every use that lives in a different block that is not part of
  // Body. For a PHI user, the "using" block is the incoming block of the use.
  for (Use &U : Inst.uses()) {
    Instruction *User = cast<Instruction>(U.getUser());
    BasicBlock *UserBB = User->getParent();
    if (PHINode *PN = dyn_cast<PHINode>(User))
      UserBB = PN->getIncomingBlock(U);

    if (InstBB != UserBB && /*!L.contains(UserBB)*/!Body.count(UserBB)) // HLSL Change
      UsesToRewrite.push_back(&U);
  }

  // If there are no uses outside the loop, exit with no change.
  if (UsesToRewrite.empty())
    return false;
#if 0 // HLSL Change
  ++NumLCSSA; // We are applying the transformation
#endif // HLSL Change

  // Invoke instructions are special in that their result value is not available
  // along their unwind edge. The code below tests to see whether DomBB
  // dominates
  // the value, so adjust DomBB to the normal destination block, which is
  // effectively where the value is first usable.
  BasicBlock *DomBB = Inst.getParent();
  if (InvokeInst *Inv = dyn_cast<InvokeInst>(&Inst))
    DomBB = Inv->getNormalDest();

  DomTreeNode *DomNode = DT.getNode(DomBB);

  SmallVector<PHINode *, 16> AddedPHIs;
  SmallVector<PHINode *, 8> PostProcessPHIs;

  SSAUpdater SSAUpdate;
  SSAUpdate.Initialize(Inst.getType(), Inst.getName());

  // Insert the LCSSA phi's into all of the exit blocks dominated by the
  // value, and add them to the Phi's map.
  for (SmallVectorImpl<BasicBlock *>::const_iterator BBI = ExitBlocks.begin(),
                                                     BBE = ExitBlocks.end();
       BBI != BBE; ++BBI) {
    BasicBlock *ExitBB = *BBI;
    if (!DT.dominates(DomNode, DT.getNode(ExitBB)))
      continue;

    // If we already inserted something for this BB, don't reprocess it.
    if (SSAUpdate.HasValueForBlock(ExitBB))
      continue;

    // The new PHI goes at the very beginning of the exit block and is named
    // "<inst>.lcssa".
    PHINode *PN = PHINode::Create(Inst.getType(), PredCache.size(ExitBB),
                                  Inst.getName() + ".lcssa", ExitBB->begin());

    // Add inputs from inside the loop for this PHI.
    for (BasicBlock *Pred : PredCache.get(ExitBB)) {
      PN->addIncoming(&Inst, Pred);

      // If the exit block has a predecessor not within the loop, arrange for
      // the incoming value use corresponding to that predecessor to be
      // rewritten in terms of a different LCSSA PHI.
      if (/*!L.contains(Pred)*/ !Body.count(Pred)) // HLSL Change
        UsesToRewrite.push_back(
            &PN->getOperandUse(PN->getOperandNumForIncomingValue(
                PN->getNumIncomingValues() - 1)));
    }

    AddedPHIs.push_back(PN);

    // Remember that this phi makes the value alive in this block.
    SSAUpdate.AddAvailableValue(ExitBB, PN);

    // LoopSimplify might fail to simplify some loops (e.g. when indirect
    // branches are involved). In such situations, it might happen that an exit
    // for Loop L1 is the header of a disjoint Loop L2. Thus, when we create
    // PHIs in such an exit block, we are also inserting PHIs into L2's header.
    // This could break LCSSA form for L2 because these inserted PHIs can also
    // have uses outside of L2. Remember all PHIs in such situation as to
    // revisit them later on. FIXME: Remove this if indirectbr support into
    // LoopSimplify gets improved.
    if (auto *OtherLoop = LI->getLoopFor(ExitBB))
      if (!L.contains(OtherLoop))
        PostProcessPHIs.push_back(PN);
  }

  // Rewrite all uses outside the loop in terms of the new PHIs we just
  // inserted.
  for (unsigned i = 0, e = UsesToRewrite.size(); i != e; ++i) {
    // If this use is in an exit block, rewrite to use the newly inserted PHI.
    // This is required for correctness because SSAUpdate doesn't handle uses in
    // the same block. It assumes the PHI we inserted is at the end of the
    // block.
    Instruction *User = cast<Instruction>(UsesToRewrite[i]->getUser());
    BasicBlock *UserBB = User->getParent();
    if (PHINode *PN = dyn_cast<PHINode>(User))
      UserBB = PN->getIncomingBlock(*UsesToRewrite[i]);

    if (isa<PHINode>(UserBB->begin()) && isExitBlock(UserBB, ExitBlocks)) {
      // Tell the VHs that the uses changed. This updates SCEV's caches.
      if (UsesToRewrite[i]->get()->hasValueHandle())
        ValueHandleBase::ValueIsRAUWd(*UsesToRewrite[i], UserBB->begin());
      UsesToRewrite[i]->set(UserBB->begin());
      continue;
    }

    // Otherwise, do full PHI insertion.
    SSAUpdate.RewriteUse(*UsesToRewrite[i]);
  }

  // Post process PHI instructions that were inserted into another disjoint loop
  // and update their exits properly.
  for (auto *I : PostProcessPHIs) {
    if (I->use_empty())
      continue;

    BasicBlock *PHIBB = I->getParent();
    Loop *OtherLoop = LI->getLoopFor(PHIBB);
    SmallVector<BasicBlock *, 8> EBs;
    OtherLoop->getExitBlocks(EBs);
    if (EBs.empty())
      continue;

    // Recurse and re-process each PHI instruction. FIXME: we should really
    // convert this entire thing to a worklist approach where we process a
    // vector of instructions...
    // For the other loop the boundary is simply its own block list.
    SetVector<BasicBlock *> OtherLoopBody(OtherLoop->block_begin(), OtherLoop->block_end()); // HLSL Change
    processInstruction(OtherLoopBody, *OtherLoop, *I, DT, EBs, PredCache, LI);
  }

  // Remove PHI nodes that did not have any uses rewritten.
  for (unsigned i = 0, e = AddedPHIs.size(); i != e; ++i) {
    if (AddedPHIs[i]->use_empty())
      AddedPHIs[i]->eraseFromParent();
  }

  return true;
}
  315. // Copied from LCSSA.cpp
  316. static bool blockDominatesAnExit(BasicBlock *BB,
  317. DominatorTree &DT,
  318. const SmallVectorImpl<BasicBlock *> &ExitBlocks) {
  319. DomTreeNode *DomNode = DT.getNode(BB);
  320. for (BasicBlock *Exit : ExitBlocks)
  321. if (DT.dominates(DomNode, DT.getNode(Exit)))
  322. return true;
  323. return false;
  324. };
  325. // Copied from LCSSA.cpp
  326. //
  327. // We need to recreate the LCSSA form since our loop boundary is potentially different from
  328. // the canonical one.
  329. static bool CreateLCSSA(SetVector<BasicBlock *> &Body, const SmallVectorImpl<BasicBlock *> &ExitBlocks, Loop *L, DominatorTree &DT, LoopInfo *LI) {
  330. PredIteratorCache PredCache;
  331. bool Changed = false;
  332. // Look at all the instructions in the loop, checking to see if they have uses
  333. // outside the loop. If so, rewrite those uses.
  334. for (SetVector<BasicBlock *>::iterator BBI = Body.begin(), BBE = Body.end();
  335. BBI != BBE; ++BBI) {
  336. BasicBlock *BB = *BBI;
  337. // For large loops, avoid use-scanning by using dominance information: In
  338. // particular, if a block does not dominate any of the loop exits, then none
  339. // of the values defined in the block could be used outside the loop.
  340. if (!blockDominatesAnExit(BB, DT, ExitBlocks))
  341. continue;
  342. for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) {
  343. // Reject two common cases fast: instructions with no uses (like stores)
  344. // and instructions with one use that is in the same block as this.
  345. if (I->use_empty() ||
  346. (I->hasOneUse() && I->user_back()->getParent() == BB &&
  347. !isa<PHINode>(I->user_back())))
  348. continue;
  349. Changed |= processInstruction(Body, *L, *I, DT, ExitBlocks, PredCache, LI);
  350. }
  351. }
  352. return Changed;
  353. }
  354. static Value *GetGEPPtrOrigin(GEPOperator *GEP) {
  355. Value *Ptr = GEP->getPointerOperand();
  356. while (Ptr) {
  357. if (AllocaInst *AI = dyn_cast<AllocaInst>(Ptr)) {
  358. return AI;
  359. }
  360. else if (GEPOperator *NewGEP = dyn_cast<GEPOperator>(Ptr)) {
  361. Ptr = NewGEP->getPointerOperand();
  362. }
  363. else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) {
  364. return GV;
  365. }
  366. else {
  367. break;
  368. }
  369. }
  370. return nullptr;
  371. }
  372. // Find all blocks in the loop with instructions that
  373. // would require an unroll to be correct.
  374. //
  375. // For example:
  376. // for (int i = 0; i < 10; i++) {
  377. // gep i
  378. // }
  379. //
  380. static void FindProblemBlocks(BasicBlock *Header, const SmallVectorImpl<BasicBlock *> &BlocksInLoop, std::unordered_set<BasicBlock *> &ProblemBlocks, SetVector<AllocaInst *> &ProblemAllocas) {
  381. SmallVector<Instruction *, 16> WorkList;
  382. std::unordered_set<BasicBlock *> BlocksInLoopSet(BlocksInLoop.begin(), BlocksInLoop.end());
  383. std::unordered_set<Instruction *> InstructionsSeen;
  384. for (Instruction &I : *Header) {
  385. PHINode *PN = dyn_cast<PHINode>(&I);
  386. if (!PN)
  387. break;
  388. WorkList.push_back(PN);
  389. InstructionsSeen.insert(PN);
  390. }
  391. while (WorkList.size()) {
  392. Instruction *I = WorkList.pop_back_val();
  393. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) {
  394. Type *EltType = GEP->getType()->getPointerElementType();
  395. // NOTE: This is a very convservative in the following conditions:
  396. // - constant global resource arrays with external linkage (these can be
  397. // dynamically accessed)
  398. // - global resource arrays or alloca resource arrays, as long as all
  399. // writes come from the same original resource definition (which can
  400. // also be an array).
  401. //
  402. // We may want to make this more precise in the future if it becomes a
  403. // problem.
  404. //
  405. if (hlsl::dxilutil::IsHLSLObjectType(EltType)) {
  406. if (Value *Ptr = GetGEPPtrOrigin(cast<GEPOperator>(GEP))) {
  407. if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) {
  408. if (!GV->isExternalLinkage(llvm::GlobalValue::ExternalLinkage))
  409. ProblemBlocks.insert(GEP->getParent());
  410. }
  411. else if (AllocaInst *AI = dyn_cast<AllocaInst>(Ptr)) {
  412. ProblemAllocas.insert(AI);
  413. ProblemBlocks.insert(GEP->getParent());
  414. }
  415. }
  416. continue; // Stop Propagating
  417. }
  418. }
  419. for (User *U : I->users()) {
  420. if (Instruction *UserI = dyn_cast<Instruction>(U)) {
  421. if (!InstructionsSeen.count(UserI) &&
  422. BlocksInLoopSet.count(UserI->getParent()))
  423. {
  424. InstructionsSeen.insert(UserI);
  425. WorkList.push_back(UserI);
  426. }
  427. }
  428. }
  429. }
  430. }
  431. // Helper function for getting GEP's const index value
  432. inline static int64_t GetGEPIndex(GEPOperator *GEP, unsigned idx) {
  433. return cast<ConstantInt>(GEP->getOperand(idx + 1))->getSExtValue();
  434. }
// Replace allocas with all constant indices with scalar allocas, then promote
// them to values where possible (mem2reg).
//
// Before loop unroll, we did not have constant indices for arrays and SROA was
// unable to break them into scalars. Now that unroll has potentially given
// them constant values, we need to turn them into scalars.
//
// if "AllowOOBIndex" is true, it turns any out of bound index into 0.
// Otherwise it emits an error and fails compilation.
//
//  ItBegin/ItEnd - range of AllocaInst* to process.
//  DT, AC        - passed to PromoteMemToReg.
//  DVC           - value cache used to turn non-constant GEP indices into
//                  constants where it knows the value.
//
// Returns false if an out-of-bound index was found while AllowOOBIndex is
// false, true otherwise.
template<typename IteratorT>
static bool BreakUpArrayAllocas(bool AllowOOBIndex, IteratorT ItBegin, IteratorT ItEnd, DominatorTree *DT, AssumptionCache *AC, DxilValueCache *DVC) {
  bool Success = true;

  SmallVector<AllocaInst *, 8> WorkList(ItBegin, ItEnd);

  SmallVector<GEPOperator *, 16> GEPs;
  while (WorkList.size()) {
    AllocaInst *AI = WorkList.pop_back_val();

    Type *AllocaType = AI->getAllocatedType();

    // Only deal with array allocas.
    if (!AllocaType->isArrayTy())
      continue;

    unsigned ArraySize = AI->getAllocatedType()->getArrayNumElements();
    Type *ElementType = AllocaType->getArrayElementType();
    if (!ArraySize)
      continue;

    GEPs.clear(); // Re-use array
    // Collect the GEP users. The alloca is only split if every user is a GEP
    // of the form "gep AI, 0, <const>, ...". Clearing GEPs and breaking out
    // marks the alloca as not splittable.
    for (User *U : AI->users()) {
      // Try to set all GEP operands to constant
      if (GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
        if (!GEP->hasAllConstantIndices() && isa<GetElementPtrInst>(GEP)) {
          for (unsigned i = 0; i < GEP->getNumIndices(); i++) {
            Value *IndexOp = GEP->getOperand(i + 1);
            if (isa<Constant>(IndexOp))
              continue;

            if (Constant *C = DVC->GetConstValue(IndexOp))
              GEP->setOperand(i + 1, C);
          }
        }

        if (!GEP->hasAllConstantIndices() || GEP->getNumIndices() < 2 ||
          GetGEPIndex(GEP, 0) != 0)
        {
          GEPs.clear();
          break;
        }
        else {
          GEPs.push_back(GEP);
        }
      }
      else {
        GEPs.clear();
        break;
      }
    }

    if (!GEPs.size())
      continue;

    // One slot per array element; the scalar allocas are created lazily, only
    // for elements that are actually accessed.
    SmallVector<AllocaInst *, 8> ScalarAllocas;
    ScalarAllocas.resize(ArraySize);

    // New instructions are inserted right before the original alloca.
    IRBuilder<> B(AI);
    for (GEPOperator *GEP : GEPs) {
      int64_t idx = GetGEPIndex(GEP, 1);
      GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(GEP);

      if (idx < 0 || idx >= ArraySize) {
        if (AllowOOBIndex)
          idx = 0;
        else {
          Success = false;
          // The error can only be attached to a GEP instruction, not to a
          // constant-expression GEP.
          if (GEPInst)
            hlsl::dxilutil::EmitErrorOnInstruction(GEPInst, "Array access out of bound.");
          continue;
        }
      }
      AllocaInst *ScalarAlloca = ScalarAllocas[idx];
      if (!ScalarAlloca) {
        ScalarAlloca = B.CreateAlloca(ElementType);
        ScalarAllocas[idx] = ScalarAlloca;
        // Elements that are arrays themselves are queued so they get broken
        // up in turn.
        if (ElementType->isArrayTy()) {
          WorkList.push_back(ScalarAlloca);
        }
      }
      Value *NewPointer = nullptr;
      if (ElementType->isArrayTy()) {
        // Rebuild the GEP on the element alloca: a leading 0 followed by the
        // original indices past the one that selected this element.
        SmallVector<Value *, 2> Indices;
        Indices.push_back(B.getInt32(0));
        for (unsigned i = 2; i < GEP->getNumIndices(); i++) {
          Indices.push_back(GEP->getOperand(i + 1));
        }
        NewPointer = B.CreateGEP(ScalarAlloca, Indices);
      } else {
        NewPointer = ScalarAlloca;
      }

      GEP->replaceAllUsesWith(NewPointer);
    }

    // Element allocas of non-array type are promoted right away; slots that
    // were never accessed (still null) are dropped first.
    if (!ElementType->isArrayTy()) {
      ScalarAllocas.erase(std::remove(ScalarAllocas.begin(), ScalarAllocas.end(), nullptr), ScalarAllocas.end());
      PromoteMemToReg(ScalarAllocas, *DT, nullptr, AC);
    }
  }

  return Success;
}
  534. static void RecursivelyRemoveLoopFromQueue(LPPassManager &LPM, Loop *L) {
  535. // Copy the sub loops into a separate list because
  536. // the original list may change.
  537. SmallVector<Loop *, 4> SubLoops(L->getSubLoops().begin(), L->getSubLoops().end());
  538. // Must remove all child loops first.
  539. for (Loop *SubL : SubLoops) {
  540. RecursivelyRemoveLoopFromQueue(LPM, SubL);
  541. }
  542. LPM.deleteLoopFromQueue(L);
  543. }
  544. bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
  545. DebugLoc LoopLoc = L->getStartLoc(); // Debug location for the start of the loop.
  546. Function *F = L->getHeader()->getParent();
  547. ScalarEvolution *SE = &getAnalysis<ScalarEvolution>();
  548. DxilValueCache *DVC = &getAnalysis<DxilValueCache>();
  549. bool HasExplicitLoopCount = false;
  550. int ExplicitUnrollCountSigned = 0;
  551. // If the loop is not marked as [unroll], don't do anything.
  552. if (IsMarkedUnrollCount(L, &ExplicitUnrollCountSigned)) {
  553. HasExplicitLoopCount = true;
  554. }
  555. else if (!IsMarkedFullUnroll(L)) {
  556. return false;
  557. }
  558. unsigned ExplicitUnrollCount = 0;
  559. if (HasExplicitLoopCount) {
  560. if (ExplicitUnrollCountSigned < 1) {
  561. FailLoopUnroll(false, F->getContext(), LoopLoc, "Could not unroll loop. Invalid unroll count.");
  562. return false;
  563. }
  564. ExplicitUnrollCount = (unsigned)ExplicitUnrollCountSigned;
  565. }
  566. if (!L->isSafeToClone())
  567. return false;
  568. bool FxcCompatMode = false;
  569. if (F->getParent()->HasHLModule()) {
  570. HLModule &HM = F->getParent()->GetHLModule();
  571. FxcCompatMode = HM.GetHLOptions().bFXCCompatMode;
  572. }
  573. unsigned TripCount = 0;
  574. BasicBlock *ExitingBlock = L->getLoopLatch();
  575. if (!ExitingBlock || !L->isLoopExiting(ExitingBlock))
  576. ExitingBlock = L->getExitingBlock();
  577. if (ExitingBlock) {
  578. TripCount = SE->getSmallConstantTripCount(L, ExitingBlock);
  579. }
  580. // Analysis passes
  581. DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  582. AssumptionCache *AC =
  583. &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(*F);
  584. LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  585. Loop *OuterL = L->getParentLoop();
  586. BasicBlock *Latch = L->getLoopLatch();
  587. BasicBlock *Header = L->getHeader();
  588. BasicBlock *Predecessor = L->getLoopPredecessor();
  589. const DataLayout &DL = F->getParent()->getDataLayout();
  590. // Quit if we don't have a single latch block or predecessor
  591. if (!Latch || !Predecessor) {
  592. return false;
  593. }
  594. // If the loop exit condition is not in the latch, then the loop is not rotated. Give up.
  595. if (!cast<BranchInst>(Latch->getTerminator())->isConditional()) {
  596. return false;
  597. }
  598. SmallVector<BasicBlock *, 16> ExitBlocks;
  599. L->getExitBlocks(ExitBlocks);
  600. std::unordered_set<BasicBlock *> ExitBlockSet(ExitBlocks.begin(), ExitBlocks.end());
  601. SmallVector<BasicBlock *, 16> BlocksInLoop; // Set of blocks including both body and exits
  602. BlocksInLoop.append(L->getBlocks().begin(), L->getBlocks().end());
  603. BlocksInLoop.append(ExitBlocks.begin(), ExitBlocks.end());
  604. // Heuristically find blocks that likely need to be unrolled
  605. SetVector<AllocaInst *> ProblemAllocas;
  606. std::unordered_set<BasicBlock *> ProblemBlocks;
  607. FindProblemBlocks(L->getHeader(), BlocksInLoop, ProblemBlocks, ProblemAllocas);
  608. // Keep track of the PHI nodes at the header.
  609. SmallVector<PHINode *, 16> PHIs;
  610. for (auto it = Header->begin(); it != Header->end(); it++) {
  611. if (PHINode *PN = dyn_cast<PHINode>(it)) {
  612. PHIs.push_back(PN);
  613. }
  614. else {
  615. break;
  616. }
  617. }
  618. // Quick simplification of PHINode incoming values
  619. for (PHINode *PN : PHIs) {
  620. for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
  621. Value *OldIncomingV = PN->getIncomingValue(i);
  622. if (Instruction *IncomingI = dyn_cast<Instruction>(OldIncomingV)) {
  623. if (Value *NewIncomingV = llvm::SimplifyInstruction(IncomingI, DL)) {
  624. PN->setIncomingValue(i, NewIncomingV);
  625. }
  626. }
  627. }
  628. }
  629. SetVector<BasicBlock *> ToBeCloned; // List of blocks that will be cloned.
  630. for (BasicBlock *BB : L->getBlocks()) // Include the body right away
  631. ToBeCloned.insert(BB);
  632. // Find the exit blocks that also need to be included
  633. // in the unroll.
  634. SmallVector<BasicBlock *, 8> NewExits; // New set of exit blocks as boundaries for LCSSA
  635. SmallVector<BasicBlock *, 8> FakeExits; // Set of blocks created to allow cloning original exit blocks.
  636. for (BasicBlock *BB : ExitBlocks) {
  637. bool CloneThisExitBlock = ProblemBlocks.count(BB);
  638. if (CloneThisExitBlock) {
  639. ToBeCloned.insert(BB);
  640. // If we are cloning this basic block, we must create a new exit
  641. // block for inserting LCSSA PHI nodes.
  642. BasicBlock *FakeExit = BasicBlock::Create(BB->getContext(), "loop.exit.new");
  643. F->getBasicBlockList().insert(BB, FakeExit);
  644. TerminatorInst *OldTerm = BB->getTerminator();
  645. OldTerm->removeFromParent();
  646. FakeExit->getInstList().push_back(OldTerm);
  647. BranchInst::Create(FakeExit, BB);
  648. for (BasicBlock *Succ : successors(FakeExit)) {
  649. for (Instruction &I : *Succ) {
  650. if (PHINode *PN = dyn_cast<PHINode>(&I)) {
  651. for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
  652. if (PN->getIncomingBlock(i) == BB)
  653. PN->setIncomingBlock(i, FakeExit);
  654. }
  655. }
  656. }
  657. }
  658. NewExits.push_back(FakeExit);
  659. FakeExits.push_back(FakeExit);
  660. // Update Dom tree with new exit
  661. if (!DT->getNode(FakeExit))
  662. DT->addNewBlock(FakeExit, BB);
  663. }
  664. else {
  665. // If we are not including this exit block in the unroll,
  666. // use it for LCSSA as normal.
  667. NewExits.push_back(BB);
  668. }
  669. }
  670. // Simplify the PHI nodes that have single incoming value. The original LCSSA form
  671. // (if exists) does not necessarily work for our unroll because we may be unrolling
  672. // from a different boundary.
  673. for (BasicBlock *BB : BlocksInLoop)
  674. hlsl::dxilutil::SimplifyTrivialPHIs(BB);
  675. // Re-establish LCSSA form to get ready for unrolling.
  676. CreateLCSSA(ToBeCloned, NewExits, L, *DT, LI);
  677. SmallVector<std::unique_ptr<LoopIteration>, 16> Iterations; // List of cloned iterations
  678. bool Succeeded = false;
  679. unsigned MaxAttempt = this->MaxIterationAttempt;
  680. // If we were able to figure out the definitive trip count,
  681. // just unroll that many times.
  682. if (TripCount != 0) {
  683. MaxAttempt = TripCount;
  684. }
  685. else if (HasExplicitLoopCount) {
  686. MaxAttempt = ExplicitUnrollCount;
  687. }
  688. for (unsigned IterationI = 0; IterationI < MaxAttempt; IterationI++) {
  689. LoopIteration *PrevIteration = nullptr;
  690. if (Iterations.size())
  691. PrevIteration = Iterations.back().get();
  692. Iterations.push_back(llvm::make_unique<LoopIteration>());
  693. LoopIteration &CurIteration = *Iterations.back().get();
  694. // Clone the blocks.
  695. for (BasicBlock *BB : ToBeCloned) {
  696. BasicBlock *ClonedBB = CloneBasicBlock(BB, CurIteration.VarMap);
  697. CurIteration.VarMap[BB] = ClonedBB;
  698. ClonedBB->insertInto(F, Header);
  699. if (ExitBlockSet.count(BB))
  700. CurIteration.Extended.insert(ClonedBB);
  701. CurIteration.Body.push_back(ClonedBB);
  702. // Identify the special blocks.
  703. if (BB == Latch) {
  704. CurIteration.Latch = ClonedBB;
  705. }
  706. if (BB == Header) {
  707. CurIteration.Header = ClonedBB;
  708. }
  709. }
  710. for (BasicBlock *BB : ToBeCloned) {
  711. BasicBlock *ClonedBB = cast<BasicBlock>(CurIteration.VarMap[BB]);
  712. // If branching to outside of the loop, need to update the
  713. // phi nodes there to include new values.
  714. for (BasicBlock *Succ : successors(ClonedBB)) {
  715. if (ToBeCloned.count(Succ))
  716. continue;
  717. for (Instruction &I : *Succ) {
  718. PHINode *PN = dyn_cast<PHINode>(&I);
  719. if (!PN)
  720. break;
  721. // Find the incoming value for this new block. If there is an entry
  722. // for this block in the map, then it was defined in the loop, use it.
  723. // Otherwise it came from outside the loop.
  724. Value *OldIncoming = PN->getIncomingValueForBlock(BB);
  725. Value *NewIncoming = OldIncoming;
  726. ValueToValueMapTy::iterator Itor = CurIteration.VarMap.find(OldIncoming);
  727. if (Itor != CurIteration.VarMap.end())
  728. NewIncoming = Itor->second;
  729. PN->addIncoming(NewIncoming, ClonedBB);
  730. }
  731. }
  732. }
  733. // Remap the instructions inside of cloned blocks.
  734. for (BasicBlock *BB : CurIteration.Body) {
  735. for (Instruction &I : *BB) {
  736. ::RemapInstruction(&I, CurIteration.VarMap);
  737. }
  738. }
  739. // If this is the first iteration
  740. if (!PrevIteration) {
  741. // Replace the phi nodes in the clone block with the values coming
  742. // from outside of the loop
  743. for (PHINode *PN : PHIs) {
  744. PHINode *ClonedPN = cast<PHINode>(CurIteration.VarMap[PN]);
  745. Value *ReplacementVal = ClonedPN->getIncomingValueForBlock(Predecessor);
  746. ClonedPN->replaceAllUsesWith(ReplacementVal);
  747. ClonedPN->eraseFromParent();
  748. CurIteration.VarMap[PN] = ReplacementVal;
  749. }
  750. }
  751. else {
  752. // Replace the phi nodes with the value defined INSIDE the previous iteration.
  753. for (PHINode *PN : PHIs) {
  754. PHINode *ClonedPN = cast<PHINode>(CurIteration.VarMap[PN]);
  755. Value *ReplacementVal = PN->getIncomingValueForBlock(Latch);
  756. auto itRep = PrevIteration->VarMap.find(ReplacementVal);
  757. if (itRep != PrevIteration->VarMap.end())
  758. ReplacementVal = itRep->second;
  759. ClonedPN->replaceAllUsesWith(ReplacementVal);
  760. ClonedPN->eraseFromParent();
  761. CurIteration.VarMap[PN] = ReplacementVal;
  762. }
  763. // Make the latch of the previous iteration branch to the header
  764. // of this new iteration.
  765. if (BranchInst *BI = dyn_cast<BranchInst>(PrevIteration->Latch->getTerminator())) {
  766. for (unsigned i = 0; i < BI->getNumSuccessors(); i++) {
  767. if (BI->getSuccessor(i) == PrevIteration->Header) {
  768. BI->setSuccessor(i, CurIteration.Header);
  769. break;
  770. }
  771. }
  772. }
  773. }
  774. // Check exit condition to see if we fully unrolled the loop
  775. if (BranchInst *BI = dyn_cast<BranchInst>(CurIteration.Latch->getTerminator())) {
  776. bool Cond = false;
  777. Value *ConstantCond = BI->getCondition();
  778. if (Value *C = DVC->GetValue(ConstantCond))
  779. ConstantCond = C;
  780. if (GetConstantI1(ConstantCond, &Cond)) {
  781. if (BI->getSuccessor(Cond ? 1 : 0) == CurIteration.Header) {
  782. Succeeded = true;
  783. break;
  784. }
  785. }
  786. }
  787. // We've reached the N defined in [unroll(N)]
  788. if ((HasExplicitLoopCount && IterationI+1 >= ExplicitUnrollCount) ||
  789. (TripCount != 0 && IterationI+1 >= TripCount))
  790. {
  791. Succeeded = true;
  792. BranchInst *BI = cast<BranchInst>(CurIteration.Latch->getTerminator());
  793. BasicBlock *ExitBlock = nullptr;
  794. for (unsigned i = 0; i < BI->getNumSuccessors(); i++) {
  795. BasicBlock *Succ = BI->getSuccessor(i);
  796. if (Succ != CurIteration.Header) {
  797. ExitBlock = Succ;
  798. break;
  799. }
  800. }
  801. BranchInst *NewBI = BranchInst::Create(ExitBlock, BI);
  802. BI->replaceAllUsesWith(NewBI);
  803. BI->eraseFromParent();
  804. break;
  805. }
  806. }
  807. if (Succeeded) {
  808. // We are going to be cleaning them up later. Make sure
  809. // they're in the entry block so deleting loop blocks doesn't
  810. // kill them too.
  811. for (AllocaInst *AI : ProblemAllocas)
  812. DXASSERT_LOCALVAR(AI, AI->getParent() == &F->getEntryBlock(), "Alloca is not in entry block.");
  813. LoopIteration &FirstIteration = *Iterations.front().get();
  814. // Make the predecessor branch to the first new header.
  815. {
  816. BranchInst *BI = cast<BranchInst>(Predecessor->getTerminator());
  817. for (unsigned i = 0, NumSucc = BI->getNumSuccessors(); i < NumSucc; i++) {
  818. if (BI->getSuccessor(i) == Header) {
  819. BI->setSuccessor(i, FirstIteration.Header);
  820. }
  821. }
  822. }
  823. if (OuterL) {
  824. // Core body blocks need to be added to outer loop
  825. for (size_t i = 0; i < Iterations.size(); i++) {
  826. LoopIteration &Iteration = *Iterations[i].get();
  827. for (BasicBlock *BB : Iteration.Body) {
  828. if (!Iteration.Extended.count(BB)) {
  829. OuterL->addBasicBlockToLoop(BB, *LI);
  830. }
  831. }
  832. }
  833. // Our newly created exit blocks may need to be added to outer loop
  834. for (BasicBlock *BB : FakeExits) {
  835. if (HasSuccessorsInLoop(BB, OuterL))
  836. OuterL->addBasicBlockToLoop(BB, *LI);
  837. }
  838. // Cloned exit blocks may need to be added to outer loop
  839. for (size_t i = 0; i < Iterations.size(); i++) {
  840. LoopIteration &Iteration = *Iterations[i].get();
  841. for (BasicBlock *BB : Iteration.Extended) {
  842. if (HasSuccessorsInLoop(BB, OuterL))
  843. OuterL->addBasicBlockToLoop(BB, *LI);
  844. }
  845. }
  846. }
  847. SE->forgetLoop(L);
  848. // Remove the original blocks that we've cloned from all loops.
  849. for (BasicBlock *BB : ToBeCloned)
  850. LI->removeBlock(BB);
  851. // Remove loop and all child loops from queue.
  852. RecursivelyRemoveLoopFromQueue(LPM, L);
  853. // Remove dead blocks.
  854. for (BasicBlock *BB : ToBeCloned)
  855. DetachFromSuccessors(BB);
  856. for (BasicBlock *BB : ToBeCloned)
  857. BB->dropAllReferences();
  858. for (BasicBlock *BB : ToBeCloned)
  859. BB->eraseFromParent();
  860. // Blocks need to be removed from DomTree. There's no easy way
  861. // to remove them in the right order, so just make DomTree
  862. // recalculate.
  863. DT->recalculate(*F);
  864. if (OuterL) {
  865. // This process may have created multiple back edges for the
  866. // parent loop. Simplify to keep it well-formed.
  867. simplifyLoop(OuterL, DT, LI, this, nullptr, nullptr, AC);
  868. }
  869. // Now that we potentially turned some GEP indices into constants,
  870. // try to clean up their allocas.
  871. if (!BreakUpArrayAllocas(FxcCompatMode /* allow oob index */, ProblemAllocas.begin(), ProblemAllocas.end(), DT, AC, DVC)) {
  872. FailLoopUnroll(false, F->getContext(), LoopLoc, "Could not unroll loop due to out of bound array access.");
  873. }
  874. return true;
  875. }
  876. // If we were unsuccessful in unrolling the loop
  877. else {
  878. const char *Msg =
  879. "Could not unroll loop. Loop bound could not be deduced at compile time. "
  880. "Use [unroll(n)] to give an explicit count.";
  881. if (FxcCompatMode) {
  882. FailLoopUnroll(true /*warn only*/, F->getContext(), LoopLoc, Msg);
  883. }
  884. else {
  885. FailLoopUnroll(false /*warn only*/, F->getContext(), LoopLoc,
  886. Twine(Msg) + Twine(" Use '-HV 2016' to treat this as warning."));
  887. }
  888. // Remove all the cloned blocks
  889. for (std::unique_ptr<LoopIteration> &Ptr : Iterations) {
  890. LoopIteration &Iteration = *Ptr.get();
  891. for (BasicBlock *BB : Iteration.Body)
  892. DetachFromSuccessors(BB);
  893. }
  894. for (std::unique_ptr<LoopIteration> &Ptr : Iterations) {
  895. LoopIteration &Iteration = *Ptr.get();
  896. for (BasicBlock *BB : Iteration.Body)
  897. BB->dropAllReferences();
  898. }
  899. for (std::unique_ptr<LoopIteration> &Ptr : Iterations) {
  900. LoopIteration &Iteration = *Ptr.get();
  901. for (BasicBlock *BB : Iteration.Body)
  902. BB->eraseFromParent();
  903. }
  904. return false;
  905. }
  906. }
  907. }
  908. Pass *llvm::createDxilLoopUnrollPass(unsigned MaxIterationAttempt) {
  909. return new DxilLoopUnroll(MaxIterationAttempt);
  910. }
  // Pass registration for DxilLoopUnroll.
  //
  // The BEGIN/END pair brackets the registration and repeats the same
  // arguments: the pass class (DxilLoopUnroll), the string "dxil-loop-unroll",
  // the description "Dxil Unroll loops", and two boolean flags.
  // NOTE(review): in LLVM's registration macros the string is presumably the
  // pass's command-line argument and the two `false` values presumably mean
  // "not CFG-only" and "not an analysis" -- confirm against PassSupport.h.
  //
  // Each DEPENDENCY line between BEGIN and END names one pass that
  // DxilLoopUnroll is declared to depend on. The unroll routine in this file
  // uses loop info (LI), a dominator tree (DT), scalar evolution (SE), an
  // assumption cache (AC) and a DxilValueCache (DVC), and calls simplifyLoop();
  // presumably those objects are obtained from the passes listed here.
  911. INITIALIZE_PASS_BEGIN(DxilLoopUnroll, "dxil-loop-unroll", "Dxil Unroll loops", false, false)
  912. INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  913. INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
  914. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  915. INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
  916. INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
  917. INITIALIZE_PASS_DEPENDENCY(DxilValueCache)
  918. INITIALIZE_PASS_END(DxilLoopUnroll, "dxil-loop-unroll", "Dxil Unroll loops", false, false)