// loop_unswitch_pass.cpp
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/loop_unswitch_pass.h"
  15. #include <functional>
  16. #include <list>
  17. #include <memory>
  18. #include <type_traits>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include <utility>
  22. #include <vector>
  23. #include "source/opt/basic_block.h"
  24. #include "source/opt/dominator_tree.h"
  25. #include "source/opt/fold.h"
  26. #include "source/opt/function.h"
  27. #include "source/opt/instruction.h"
  28. #include "source/opt/ir_builder.h"
  29. #include "source/opt/ir_context.h"
  30. #include "source/opt/loop_descriptor.h"
  31. #include "source/opt/loop_utils.h"
  32. namespace spvtools {
  33. namespace opt {
namespace {
// Index of the storage-class operand in an OpTypePointer instruction.
static const uint32_t kTypePointerStorageClassInIdx = 0;
// Indices of the true and false target-label operands in an
// OpBranchConditional instruction.
static const uint32_t kBranchCondTrueLabIdInIdx = 1;
static const uint32_t kBranchCondFalseLabIdInIdx = 2;
}  // anonymous namespace
  39. namespace {
  40. // This class handle the unswitch procedure for a given loop.
  41. // The unswitch will not happen if:
  42. // - The loop has any instruction that will prevent it;
  43. // - The loop invariant condition is not uniform.
  44. class LoopUnswitch {
  45. public:
  // Builds an unswitch helper for |loop| (a loop of |function| registered in
  // |loop_desc|). No analysis is performed until CanUnswitchLoop is called.
  LoopUnswitch(IRContext* context, Function* function, Loop* loop,
               LoopDescriptor* loop_desc)
      : function_(function),
        loop_(loop),
        loop_desc_(*loop_desc),
        context_(context),
        switch_block_(nullptr) {}
  53. // Returns true if the loop can be unswitched.
  54. // Can be unswitch if:
  55. // - The loop has no instructions that prevents it (such as barrier);
  56. // - The loop has one conditional branch or switch that do not depends on the
  57. // loop;
  58. // - The loop invariant condition is uniform;
  59. bool CanUnswitchLoop() {
  60. if (switch_block_) return true;
  61. if (loop_->IsSafeToClone()) return false;
  62. CFG& cfg = *context_->cfg();
  63. for (uint32_t bb_id : loop_->GetBlocks()) {
  64. BasicBlock* bb = cfg.block(bb_id);
  65. if (bb->terminator()->IsBranch() &&
  66. bb->terminator()->opcode() != SpvOpBranch) {
  67. if (IsConditionLoopInvariant(bb->terminator())) {
  68. switch_block_ = bb;
  69. break;
  70. }
  71. }
  72. }
  73. return switch_block_;
  74. }
  75. // Return the iterator to the basic block |bb|.
  76. Function::iterator FindBasicBlockPosition(BasicBlock* bb_to_find) {
  77. Function::iterator it = function_->FindBlock(bb_to_find->id());
  78. assert(it != function_->end() && "Basic Block not found");
  79. return it;
  80. }
  81. // Creates a new basic block and insert it into the function |fn| at the
  82. // position |ip|. This function preserves the def/use and instr to block
  83. // managers.
  84. BasicBlock* CreateBasicBlock(Function::iterator ip) {
  85. analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
  86. BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr<BasicBlock>(
  87. new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
  88. context_, SpvOpLabel, 0, context_->TakeNextId(), {})))));
  89. bb->SetParent(function_);
  90. def_use_mgr->AnalyzeInstDef(bb->GetLabelInst());
  91. context_->set_instr_block(bb->GetLabelInst(), bb);
  92. return bb;
  93. }
  // Unswitches |loop_|: hoists the loop-invariant branch found by
  // CanUnswitchLoop out of the loop, cloning the loop once per constant
  // branch target and specializing each clone. Requires CanUnswitchLoop to
  // have succeeded, a pre-header, and LCSSA form.
  void PerformUnswitch() {
    assert(CanUnswitchLoop() &&
           "Cannot unswitch if there is not constant condition");
    assert(loop_->GetPreHeaderBlock() && "This loop has no pre-header block");
    assert(loop_->IsLCSSA() && "This loop is not in LCSSA form");
    CFG& cfg = *context_->cfg();
    DominatorTree* dom_tree =
        &context_->GetDominatorAnalysis(function_)->GetDomTree();
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
    LoopUtils loop_utils(context_, loop_);
    //////////////////////////////////////////////////////////////////////////
    // Step 1: Create the if merge block for structured modules.
    // To do so, the |loop_| merge block will become the if's one and we
    // create a merge for the loop. This will limit the amount of duplicated
    // code the structured control flow imposes.
    // For non structured program, the new loop will be connected to
    // the old loop's exit blocks.
    //////////////////////////////////////////////////////////////////////////
    // Get the merge block if it exists.
    BasicBlock* if_merge_block = loop_->GetMergeBlock();
    // The merge block is only created if the loop has a unique exit block. We
    // have this guarantee for structured loops, for compute loop it will
    // trivially help maintain both a structured-like form and LCSAA.
    BasicBlock* loop_merge_block =
        if_merge_block
            ? CreateBasicBlock(FindBasicBlockPosition(if_merge_block))
            : nullptr;
    if (loop_merge_block) {
      // Add the instruction and update managers.
      InstructionBuilder builder(
          context_, loop_merge_block,
          IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
      builder.AddBranch(if_merge_block->id());
      // Subsequent instructions (the cloned phis) must precede the branch.
      builder.SetInsertPoint(&*loop_merge_block->begin());
      cfg.RegisterBlock(loop_merge_block);
      def_use_mgr->AnalyzeInstDef(loop_merge_block->GetLabelInst());
      // Update CFG: move each phi of the old merge block into the new loop
      // merge block, and leave behind a phi with a single incoming edge.
      if_merge_block->ForEachPhiInst(
          [loop_merge_block, &builder, this](Instruction* phi) {
            Instruction* cloned = phi->Clone(context_);
            builder.AddInstruction(std::unique_ptr<Instruction>(cloned));
            phi->SetInOperand(0, {cloned->result_id()});
            phi->SetInOperand(1, {loop_merge_block->id()});
            // Drop all remaining incoming pairs (operands 2 and up).
            for (uint32_t j = phi->NumInOperands() - 1; j > 1; j--)
              phi->RemoveInOperand(j);
          });
      // Copy the predecessor list (will get invalidated otherwise).
      std::vector<uint32_t> preds = cfg.preds(if_merge_block->id());
      // Redirect all former predecessors of |if_merge_block| to the new
      // loop merge block.
      for (uint32_t pid : preds) {
        if (pid == loop_merge_block->id()) continue;
        BasicBlock* p_bb = cfg.block(pid);
        p_bb->ForEachSuccessorLabel(
            [if_merge_block, loop_merge_block](uint32_t* id) {
              if (*id == if_merge_block->id()) *id = loop_merge_block->id();
            });
        cfg.AddEdge(pid, loop_merge_block->id());
      }
      cfg.RemoveNonExistingEdges(if_merge_block->id());
      // Update loop descriptor: the new merge belongs to the parent loop.
      if (Loop* ploop = loop_->GetParent()) {
        ploop->AddBasicBlock(loop_merge_block);
        loop_desc_.SetBasicBlockToLoop(loop_merge_block->id(), ploop);
      }
      // Update the dominator tree: splice |loop_merge_block| between
      // |if_merge_block| and its former parent.
      DominatorTreeNode* loop_merge_dtn =
          dom_tree->GetOrInsertNode(loop_merge_block);
      DominatorTreeNode* if_merge_block_dtn =
          dom_tree->GetOrInsertNode(if_merge_block);
      loop_merge_dtn->parent_ = if_merge_block_dtn->parent_;
      loop_merge_dtn->children_.push_back(if_merge_block_dtn);
      loop_merge_dtn->parent_->children_.push_back(loop_merge_dtn);
      if_merge_block_dtn->parent_->children_.erase(std::find(
          if_merge_block_dtn->parent_->children_.begin(),
          if_merge_block_dtn->parent_->children_.end(), if_merge_block_dtn));
      if_merge_block_dtn->parent_ = loop_merge_dtn;
      loop_->SetMergeBlock(loop_merge_block);
    }
    ////////////////////////////////////////////////////////////////////////
    // Step 2: Build a new preheader for |loop_|, use the old one
    // for the constant branch.
    ////////////////////////////////////////////////////////////////////////
    BasicBlock* if_block = loop_->GetPreHeaderBlock();
    // If this preheader is the parent loop header,
    // we need to create a dedicated block for the if.
    BasicBlock* loop_pre_header =
        CreateBasicBlock(++FindBasicBlockPosition(if_block));
    InstructionBuilder(
        context_, loop_pre_header,
        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping)
        .AddBranch(loop_->GetHeaderBlock()->id());
    // The old preheader now jumps to the new one.
    if_block->tail()->SetInOperand(0, {loop_pre_header->id()});
    // Update loop descriptor.
    if (Loop* ploop = loop_desc_[if_block]) {
      ploop->AddBasicBlock(loop_pre_header);
      loop_desc_.SetBasicBlockToLoop(loop_pre_header->id(), ploop);
    }
    // Update the CFG.
    cfg.RegisterBlock(loop_pre_header);
    def_use_mgr->AnalyzeInstDef(loop_pre_header->GetLabelInst());
    cfg.AddEdge(if_block->id(), loop_pre_header->id());
    cfg.RemoveNonExistingEdges(loop_->GetHeaderBlock()->id());
    // Header phis must now name the new preheader as their incoming block.
    loop_->GetHeaderBlock()->ForEachPhiInst(
        [loop_pre_header, if_block](Instruction* phi) {
          phi->ForEachInId([loop_pre_header, if_block](uint32_t* id) {
            if (*id == if_block->id()) {
              *id = loop_pre_header->id();
            }
          });
        });
    loop_->SetPreHeaderBlock(loop_pre_header);
    // Update the dominator tree: the new preheader sits between |if_block|
    // and the loop header.
    DominatorTreeNode* loop_pre_header_dtn =
        dom_tree->GetOrInsertNode(loop_pre_header);
    DominatorTreeNode* if_block_dtn = dom_tree->GetTreeNode(if_block);
    loop_pre_header_dtn->parent_ = if_block_dtn;
    assert(
        if_block_dtn->children_.size() == 1 &&
        "A loop preheader should only have the header block as a child in the "
        "dominator tree");
    loop_pre_header_dtn->children_.push_back(if_block_dtn->children_[0]);
    if_block_dtn->children_.clear();
    if_block_dtn->children_.push_back(loop_pre_header_dtn);
    // Make domination queries valid.
    dom_tree->ResetDFNumbering();
    // Compute an ordered list of basic block to clone: loop blocks +
    // pre-header + merge block.
    loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks_, true, true);
    /////////////////////////////
    // Do the actual unswitch: //
    // - Clone the loop        //
    // - Connect exits         //
    // - Specialize the loop   //
    /////////////////////////////
    Instruction* iv_condition = &*switch_block_->tail();
    SpvOp iv_opcode = iv_condition->opcode();
    Instruction* condition =
        def_use_mgr->GetDef(iv_condition->GetOperand(0).words[0]);
    analysis::ConstantManager* cst_mgr = context_->get_constant_mgr();
    const analysis::Type* cond_type =
        context_->get_type_mgr()->GetType(condition->type_id());
    // Build the list of value for which we need to clone and specialize the
    // loop.
    std::vector<std::pair<Instruction*, BasicBlock*>> constant_branch;
    // Special case for the original loop: it keeps one of the constant
    // values (or the switch default) instead of being cloned.
    Instruction* original_loop_constant_value;
    BasicBlock* original_loop_target;
    if (iv_opcode == SpvOpBranchConditional) {
      // One clone specialized for 'false'; the original loop takes 'true'.
      constant_branch.emplace_back(
          cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {0})),
          nullptr);
      original_loop_constant_value =
          cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {1}));
    } else {
      // We are looking to take the default branch, so we can't provide a
      // specific value.
      original_loop_constant_value = nullptr;
      // One clone per switch case literal.
      for (uint32_t i = 2; i < iv_condition->NumInOperands(); i += 2) {
        constant_branch.emplace_back(
            cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(
                cond_type, iv_condition->GetInOperand(i).words)),
            nullptr);
      }
    }
    // Get the loop landing pads.
    std::unordered_set<uint32_t> if_merging_blocks;
    std::function<bool(uint32_t)> is_from_original_loop;
    if (loop_->GetHeaderBlock()->GetLoopMergeInst()) {
      if_merging_blocks.insert(if_merge_block->id());
      is_from_original_loop = [this](uint32_t id) {
        return loop_->IsInsideLoop(id) || loop_->GetMergeBlock()->id() == id;
      };
    } else {
      loop_->GetExitBlocks(&if_merging_blocks);
      is_from_original_loop = [this](uint32_t id) {
        return loop_->IsInsideLoop(id);
      };
    }
    for (auto& specialisation_pair : constant_branch) {
      Instruction* specialisation_value = specialisation_pair.first;
      //////////////////////////////////////////////////////////
      // Step 3: Duplicate |loop_|.
      //////////////////////////////////////////////////////////
      LoopUtils::LoopCloningResult clone_result;
      Loop* cloned_loop =
          loop_utils.CloneLoop(&clone_result, ordered_loop_blocks_);
      specialisation_pair.second = cloned_loop->GetPreHeaderBlock();
      ////////////////////////////////////
      // Step 4: Specialize the loop.   //
      ////////////////////////////////////
      {
        std::unordered_set<uint32_t> dead_blocks;
        std::unordered_set<uint32_t> unreachable_merges;
        SimplifyLoop(
            make_range(
                UptrVectorIterator<BasicBlock>(&clone_result.cloned_bb_,
                                               clone_result.cloned_bb_.begin()),
                UptrVectorIterator<BasicBlock>(&clone_result.cloned_bb_,
                                               clone_result.cloned_bb_.end())),
            cloned_loop, condition, specialisation_value, &dead_blocks);
        // We tagged dead blocks, create the loop before we invalidate any
        // basic block.
        cloned_loop =
            CleanLoopNest(cloned_loop, dead_blocks, &unreachable_merges);
        CleanUpCFG(
            UptrVectorIterator<BasicBlock>(&clone_result.cloned_bb_,
                                           clone_result.cloned_bb_.begin()),
            dead_blocks, unreachable_merges);
        ///////////////////////////////////////////////////////////
        // Step 5: Connect convergent edges to the landing pads. //
        ///////////////////////////////////////////////////////////
        for (uint32_t merge_bb_id : if_merging_blocks) {
          BasicBlock* merge = context_->cfg()->block(merge_bb_id);
          // We are in LCSSA so we only care about phi instructions.
          merge->ForEachPhiInst([is_from_original_loop, &dead_blocks,
                                 &clone_result](Instruction* phi) {
            uint32_t num_in_operands = phi->NumInOperands();
            for (uint32_t i = 0; i < num_in_operands; i += 2) {
              uint32_t pred = phi->GetSingleWordInOperand(i + 1);
              if (is_from_original_loop(pred)) {
                // Map the original predecessor to its clone and append the
                // corresponding (value, pred) pair if the clone is live.
                pred = clone_result.value_map_.at(pred);
                if (!dead_blocks.count(pred)) {
                  uint32_t incoming_value_id = phi->GetSingleWordInOperand(i);
                  // Not all the incoming value are coming from the loop.
                  ValueMapTy::iterator new_value =
                      clone_result.value_map_.find(incoming_value_id);
                  if (new_value != clone_result.value_map_.end()) {
                    incoming_value_id = new_value->second;
                  }
                  phi->AddOperand({SPV_OPERAND_TYPE_ID, {incoming_value_id}});
                  phi->AddOperand({SPV_OPERAND_TYPE_ID, {pred}});
                }
              }
            }
          });
        }
      }
      // Splice the cloned blocks into the function right after |if_block|.
      function_->AddBasicBlocks(clone_result.cloned_bb_.begin(),
                                clone_result.cloned_bb_.end(),
                                ++FindBasicBlockPosition(if_block));
    }
    // Same as above but specialize the existing loop
    {
      std::unordered_set<uint32_t> dead_blocks;
      std::unordered_set<uint32_t> unreachable_merges;
      SimplifyLoop(make_range(function_->begin(), function_->end()), loop_,
                   condition, original_loop_constant_value, &dead_blocks);
      for (uint32_t merge_bb_id : if_merging_blocks) {
        BasicBlock* merge = context_->cfg()->block(merge_bb_id);
        // LCSSA, so we only care about phi instructions.
        // If we the phi is reduced to a single incoming branch, do not
        // propagate it to preserve LCSSA.
        PatchPhis(merge, dead_blocks, true);
      }
      if (if_merge_block) {
        // The merge block becomes unreachable when every predecessor died.
        bool has_live_pred = false;
        for (uint32_t pid : cfg.preds(if_merge_block->id())) {
          if (!dead_blocks.count(pid)) {
            has_live_pred = true;
            break;
          }
        }
        if (!has_live_pred) unreachable_merges.insert(if_merge_block->id());
      }
      original_loop_target = loop_->GetPreHeaderBlock();
      // We tagged dead blocks, prune the loop descriptor from any dead loops.
      // After this call, |loop_| can be nullptr (i.e. the unswitch killed
      // this loop).
      loop_ = CleanLoopNest(loop_, dead_blocks, &unreachable_merges);
      CleanUpCFG(function_->begin(), dead_blocks, unreachable_merges);
    }
    /////////////////////////////////////
    // Finally: connect the new loops. //
    /////////////////////////////////////
    // Delete the old jump
    context_->KillInst(&*if_block->tail());
    InstructionBuilder builder(context_, if_block);
    if (iv_opcode == SpvOpBranchConditional) {
      assert(constant_branch.size() == 1);
      builder.AddConditionalBranch(
          condition->result_id(), original_loop_target->id(),
          constant_branch[0].second->id(),
          if_merge_block ? if_merge_block->id() : kInvalidId);
    } else {
      // Rebuild an OpSwitch whose targets are the specialized clones.
      std::vector<std::pair<Operand::OperandData, uint32_t>> targets;
      for (auto& t : constant_branch) {
        targets.emplace_back(t.first->GetInOperand(0).words, t.second->id());
      }
      builder.AddSwitch(condition->result_id(), original_loop_target->id(),
                        targets,
                        if_merge_block ? if_merge_block->id() : kInvalidId);
    }
    switch_block_ = nullptr;
    ordered_loop_blocks_.clear();
    context_->InvalidateAnalysesExceptFor(
        IRContext::Analysis::kAnalysisLoopAnalysis);
  }
  // Returns true if the unswitch killed the original |loop_|.
  bool WasLoopKilled() const { return loop_ == nullptr; }

 private:
  using ValueMapTy = std::unordered_map<uint32_t, uint32_t>;
  using BlockMapTy = std::unordered_map<uint32_t, BasicBlock*>;

  // Function containing the loop being unswitched.
  Function* function_;
  // The loop to unswitch; set to nullptr when the unswitch kills it.
  Loop* loop_;
  // Loop descriptor of |function_|, kept in sync during the transform.
  LoopDescriptor& loop_desc_;
  IRContext* context_;
  // Block holding the loop-invariant branch/switch to hoist. Found by
  // CanUnswitchLoop and reset to nullptr at the end of PerformUnswitch.
  BasicBlock* switch_block_;
  // Map between instructions and if they are dynamically uniform.
  std::unordered_map<uint32_t, bool> dynamically_uniform_;
  // The loop basic blocks in structured order.
  std::vector<BasicBlock*> ordered_loop_blocks_;

  // Returns the next usable id for the context.
  uint32_t TakeNextId() { return context_->TakeNextId(); }
  // Patches |bb|'s phi instruction by removing incoming value from unexisting
  // or tagged as dead branches. If a phi is left with a single incoming edge
  // and |preserve_phi| is false, the phi is replaced by its remaining value
  // and killed.
  void PatchPhis(BasicBlock* bb,
                 const std::unordered_set<uint32_t>& dead_blocks,
                 bool preserve_phi) {
    CFG& cfg = *context_->cfg();
    std::vector<Instruction*> phi_to_kill;
    const std::vector<uint32_t>& bb_preds = cfg.preds(bb->id());
    // An incoming branch is dead when its block is tagged dead or is no
    // longer a CFG predecessor of |bb|.
    auto is_branch_dead = [&bb_preds, &dead_blocks](uint32_t id) {
      return dead_blocks.count(id) ||
             std::find(bb_preds.begin(), bb_preds.end(), id) == bb_preds.end();
    };
    bb->ForEachPhiInst([&phi_to_kill, &is_branch_dead, preserve_phi,
                        this](Instruction* insn) {
      uint32_t i = 0;
      // Walk the (value, predecessor) operand pairs, dropping dead ones.
      while (i < insn->NumInOperands()) {
        uint32_t incoming_id = insn->GetSingleWordInOperand(i + 1);
        if (is_branch_dead(incoming_id)) {
          // Remove the incoming block id operand.
          insn->RemoveInOperand(i + 1);
          // Remove the definition id operand.
          insn->RemoveInOperand(i);
          // Do not advance |i|: operands shifted down into this slot.
          continue;
        }
        i += 2;
      }
      // If there is only 1 remaining edge, propagate the value and
      // kill the instruction (unless |preserve_phi| asks to keep it, e.g.
      // to maintain LCSSA form).
      if (insn->NumInOperands() == 2 && !preserve_phi) {
        phi_to_kill.push_back(insn);
        context_->ReplaceAllUsesWith(insn->result_id(),
                                     insn->GetSingleWordInOperand(0));
      }
    });
    // Kill outside ForEachPhiInst to avoid invalidating the iteration.
    for (Instruction* insn : phi_to_kill) {
      context_->KillInst(insn);
    }
  }
  // Removes any block that is tagged as dead, if the block is in
  // |unreachable_merges| then all block's instructions are replaced by a
  // OpUnreachable (the label is kept so structured merge references stay
  // valid).
  void CleanUpCFG(UptrVectorIterator<BasicBlock> bb_it,
                  const std::unordered_set<uint32_t>& dead_blocks,
                  const std::unordered_set<uint32_t>& unreachable_merges) {
    CFG& cfg = *context_->cfg();
    while (bb_it != bb_it.End()) {
      BasicBlock& bb = *bb_it;
      if (unreachable_merges.count(bb.id())) {
        // Skip blocks already reduced to a lone OpUnreachable.
        if (bb.begin() != bb.tail() ||
            bb.terminator()->opcode() != SpvOpUnreachable) {
          // Make unreachable, but leave the label.
          bb.KillAllInsts(false);
          InstructionBuilder(context_, &bb).AddUnreachable();
          cfg.RemoveNonExistingEdges(bb.id());
        }
        ++bb_it;
      } else if (dead_blocks.count(bb.id())) {
        cfg.ForgetBlock(&bb);
        // Kill this block.
        bb.KillAllInsts(true);
        // Erase returns the iterator to the next block.
        bb_it = bb_it.Erase();
      } else {
        // Live block: just refresh its CFG edges.
        cfg.RemoveNonExistingEdges(bb.id());
        ++bb_it;
      }
    }
  }
  473. // Return true if |c_inst| is a Boolean constant and set |cond_val| with the
  474. // value that |c_inst|
  475. bool GetConstCondition(const Instruction* c_inst, bool* cond_val) {
  476. bool cond_is_const;
  477. switch (c_inst->opcode()) {
  478. case SpvOpConstantFalse: {
  479. *cond_val = false;
  480. cond_is_const = true;
  481. } break;
  482. case SpvOpConstantTrue: {
  483. *cond_val = true;
  484. cond_is_const = true;
  485. } break;
  486. default: { cond_is_const = false; } break;
  487. }
  488. return cond_is_const;
  489. }
  // Simplifies |loop| assuming the instruction |to_version_insn| takes the
  // value |cst_value|. |block_range| is an iterator range returning the loop
  // basic blocks in a structured order (dominator first).
  // The function will ignore basic blocks returned by |block_range| if they
  // does not belong to the loop.
  // The set |dead_blocks| will contain all the dead basic blocks.
  //
  // Requirements:
  // - |loop| must be in the LCSSA form;
  // - |cst_value| must be constant or null (to represent the default target
  // of an OpSwitch).
  void SimplifyLoop(IteratorRange<UptrVectorIterator<BasicBlock>> block_range,
                    Loop* loop, Instruction* to_version_insn,
                    Instruction* cst_value,
                    std::unordered_set<uint32_t>* dead_blocks) {
    CFG& cfg = *context_->cfg();
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
    // Blocks outside |loop| must be left untouched.
    std::function<bool(uint32_t)> ignore_node;
    ignore_node = [loop](uint32_t bb_id) { return !loop->IsInsideLoop(bb_id); };
    // Collect the in-loop uses up front: rewriting while ForEachUse iterates
    // would invalidate the def/use traversal.
    std::vector<std::pair<Instruction*, uint32_t>> use_list;
    def_use_mgr->ForEachUse(to_version_insn,
                            [&use_list, &ignore_node, this](
                                Instruction* inst, uint32_t operand_index) {
                              BasicBlock* bb = context_->get_instr_block(inst);
                              if (!bb || ignore_node(bb->id())) {
                                // Out of the loop, the specialization does not
                                // apply any more.
                                return;
                              }
                              use_list.emplace_back(inst, operand_index);
                            });
    // First pass: inject the specialized value into the loop (and only the
    // loop).
    for (auto use : use_list) {
      Instruction* inst = use.first;
      uint32_t operand_index = use.second;
      BasicBlock* bb = context_->get_instr_block(inst);
      // If it is not a branch, simply inject the value.
      if (!inst->IsBranch()) {
        // To also handle switch, cst_value can be nullptr: this case
        // means that we are looking to branch to the default target of
        // the switch. We don't actually know its value so we don't touch
        // it if it not a switch.
        if (cst_value) {
          inst->SetOperand(operand_index, {cst_value->result_id()});
          def_use_mgr->AnalyzeInstUse(inst);
        }
      }
      // The user is a branch, kill dead branches.
      uint32_t live_target = 0;
      // NOTE(review): |dead_branches| appears unused below — kept as-is.
      std::unordered_set<uint32_t> dead_branches;
      switch (inst->opcode()) {
        case SpvOpBranchConditional: {
          assert(cst_value && "No constant value to specialize !");
          bool branch_cond = false;
          if (GetConstCondition(cst_value, &branch_cond)) {
            uint32_t true_label =
                inst->GetSingleWordInOperand(kBranchCondTrueLabIdInIdx);
            uint32_t false_label =
                inst->GetSingleWordInOperand(kBranchCondFalseLabIdInIdx);
            live_target = branch_cond ? true_label : false_label;
            uint32_t dead_target = !branch_cond ? true_label : false_label;
            cfg.RemoveEdge(bb->id(), dead_target);
          }
          break;
        }
        case SpvOpSwitch: {
          // Start from the default target; override it if a case literal
          // matches |cst_value|.
          live_target = inst->GetSingleWordInOperand(1);
          if (cst_value) {
            if (!cst_value->IsConstant()) break;
            const Operand& cst = cst_value->GetInOperand(0);
            for (uint32_t i = 2; i < inst->NumInOperands(); i += 2) {
              const Operand& literal = inst->GetInOperand(i);
              if (literal == cst) {
                live_target = inst->GetSingleWordInOperand(i + 1);
                break;
              }
            }
          }
          // Remove every CFG edge that does not go to the live target.
          for (uint32_t i = 1; i < inst->NumInOperands(); i += 2) {
            uint32_t id = inst->GetSingleWordInOperand(i);
            if (id != live_target) {
              cfg.RemoveEdge(bb->id(), id);
            }
          }
          // Falls through to the no-op default.
        }
        default:
          break;
      }
      if (live_target != 0) {
        // A single live target remains: replace the conditional terminator
        // (and its merge instruction, if any) with a plain branch.
        // Check for the presence of the merge block.
        if (Instruction* merge = bb->GetMergeInst()) context_->KillInst(merge);
        context_->KillInst(&*bb->tail());
        InstructionBuilder builder(context_, bb,
                                   IRContext::kAnalysisDefUse |
                                       IRContext::kAnalysisInstrToBlockMapping);
        builder.AddBranch(live_target);
      }
    }
    // Go through the loop basic block and tag all blocks that are obviously
    // dead.
    std::unordered_set<uint32_t> visited;
    for (BasicBlock& bb : block_range) {
      if (ignore_node(bb.id())) continue;
      visited.insert(bb.id());
      // Check if this block is dead, if so tag it as dead otherwise patch phi
      // instructions.
      bool has_live_pred = false;
      for (uint32_t pid : cfg.preds(bb.id())) {
        if (!dead_blocks->count(pid)) {
          has_live_pred = true;
          break;
        }
      }
      if (!has_live_pred) {
        dead_blocks->insert(bb.id());
        const BasicBlock& cbb = bb;
        // Patch the phis for any back-edge.
        cbb.ForEachSuccessorLabel(
            [dead_blocks, &visited, &cfg, this](uint32_t id) {
              if (!visited.count(id) || dead_blocks->count(id)) return;
              BasicBlock* succ = cfg.block(id);
              PatchPhis(succ, *dead_blocks, false);
            });
        continue;
      }
      // Update the phi instructions, some incoming branch have/will disappear.
      PatchPhis(&bb, *dead_blocks, /* preserve_phi = */ false);
    }
  }
  620. // Returns true if the header is not reachable or tagged as dead or if we
  621. // never loop back.
  622. bool IsLoopDead(BasicBlock* header, BasicBlock* latch,
  623. const std::unordered_set<uint32_t>& dead_blocks) {
  624. if (!header || dead_blocks.count(header->id())) return true;
  625. if (!latch || dead_blocks.count(latch->id())) return true;
  626. for (uint32_t pid : context_->cfg()->preds(header->id())) {
  627. if (!dead_blocks.count(pid)) {
  628. // Seems reachable.
  629. return false;
  630. }
  631. }
  632. return true;
  633. }
  // Cleans the loop nest under |loop| and reflect changes to the loop
  // descriptor. This will kill all descriptors that represent dead loops.
  // If |loop_| is killed, it will be set to nullptr.
  // Any merge blocks that become unreachable will be added to
  // |unreachable_merges|.
  // The function returns the pointer to |loop| or nullptr if the loop was
  // killed.
  Loop* CleanLoopNest(Loop* loop,
                      const std::unordered_set<uint32_t>& dead_blocks,
                      std::unordered_set<uint32_t>* unreachable_merges) {
    // This represent the pair of dead loop and nearest alive parent (nullptr
    // if no parent).
    std::unordered_map<Loop*, Loop*> dead_loops;
    auto get_parent = [&dead_loops](Loop* l) -> Loop* {
      std::unordered_map<Loop*, Loop*>::iterator it = dead_loops.find(l);
      if (it != dead_loops.end()) return it->second;
      return nullptr;
    };
    bool is_main_loop_dead =
        IsLoopDead(loop->GetHeaderBlock(), loop->GetLatchBlock(), dead_blocks);
    if (is_main_loop_dead) {
      // Drop the structured merge instruction of the dead loop header.
      if (Instruction* merge = loop->GetHeaderBlock()->GetLoopMergeInst()) {
        context_->KillInst(merge);
      }
      dead_loops[loop] = loop->GetParent();
    } else {
      // Sentinel entry: a live |loop| maps to itself so that children of
      // dead sub-loops re-parent onto it (removed again below).
      dead_loops[loop] = loop;
    }
    // For each loop, check if we killed it. If we did, find a suitable parent
    // for its children.
    for (Loop& sub_loop :
         make_range(++TreeDFIterator<Loop>(loop), TreeDFIterator<Loop>())) {
      if (IsLoopDead(sub_loop.GetHeaderBlock(), sub_loop.GetLatchBlock(),
                     dead_blocks)) {
        if (Instruction* merge =
                sub_loop.GetHeaderBlock()->GetLoopMergeInst()) {
          context_->KillInst(merge);
        }
        // Depth-first order guarantees the parent was processed first.
        dead_loops[&sub_loop] = get_parent(&sub_loop);
      } else {
        // The loop is alive, check if its merge block is dead, if it is, tag
        // it as required.
        if (sub_loop.GetMergeBlock()) {
          uint32_t merge_id = sub_loop.GetMergeBlock()->id();
          if (dead_blocks.count(merge_id)) {
            unreachable_merges->insert(sub_loop.GetMergeBlock()->id());
          }
        }
      }
    }
    // Remove the sentinel so the live main loop is not deleted below.
    if (!is_main_loop_dead) dead_loops.erase(loop);
    // Remove dead blocks from live loops.
    for (uint32_t bb_id : dead_blocks) {
      Loop* l = loop_desc_[bb_id];
      if (l) {
        l->RemoveBasicBlock(bb_id);
        loop_desc_.ForgetBasicBlock(bb_id);
      }
    }
    // Delete every dead loop descriptor; clear |loop| if it was among them.
    std::for_each(
        dead_loops.begin(), dead_loops.end(),
        [&loop,
         this](std::unordered_map<Loop*, Loop*>::iterator::reference it) {
          if (it.first == loop) loop = nullptr;
          loop_desc_.RemoveLoop(it.first);
        });
    return loop;
  }
  702. // Returns true if |var| is dynamically uniform.
  703. // Note: this is currently approximated as uniform.
  704. bool IsDynamicallyUniform(Instruction* var, const BasicBlock* entry,
  705. const DominatorTree& post_dom_tree) {
  706. assert(post_dom_tree.IsPostDominator());
  707. analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
  708. auto it = dynamically_uniform_.find(var->result_id());
  709. if (it != dynamically_uniform_.end()) return it->second;
  710. analysis::DecorationManager* dec_mgr = context_->get_decoration_mgr();
  711. bool& is_uniform = dynamically_uniform_[var->result_id()];
  712. is_uniform = false;
  713. dec_mgr->WhileEachDecoration(var->result_id(), SpvDecorationUniform,
  714. [&is_uniform](const Instruction&) {
  715. is_uniform = true;
  716. return false;
  717. });
  718. if (is_uniform) {
  719. return is_uniform;
  720. }
  721. BasicBlock* parent = context_->get_instr_block(var);
  722. if (!parent) {
  723. return is_uniform = true;
  724. }
  725. if (!post_dom_tree.Dominates(parent->id(), entry->id())) {
  726. return is_uniform = false;
  727. }
  728. if (var->opcode() == SpvOpLoad) {
  729. const uint32_t PtrTypeId =
  730. def_use_mgr->GetDef(var->GetSingleWordInOperand(0))->type_id();
  731. const Instruction* PtrTypeInst = def_use_mgr->GetDef(PtrTypeId);
  732. uint32_t storage_class =
  733. PtrTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx);
  734. if (storage_class != SpvStorageClassUniform &&
  735. storage_class != SpvStorageClassUniformConstant) {
  736. return is_uniform = false;
  737. }
  738. } else {
  739. if (!context_->IsCombinatorInstruction(var)) {
  740. return is_uniform = false;
  741. }
  742. }
  743. return is_uniform = var->WhileEachInId([entry, &post_dom_tree,
  744. this](const uint32_t* id) {
  745. return IsDynamicallyUniform(context_->get_def_use_mgr()->GetDef(*id),
  746. entry, post_dom_tree);
  747. });
  748. }
  749. // Returns true if |insn| is constant and dynamically uniform within the loop.
  750. bool IsConditionLoopInvariant(Instruction* insn) {
  751. assert(insn->IsBranch());
  752. assert(insn->opcode() != SpvOpBranch);
  753. analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
  754. Instruction* condition = def_use_mgr->GetDef(insn->GetOperand(0).words[0]);
  755. return !loop_->IsInsideLoop(condition) &&
  756. IsDynamicallyUniform(
  757. condition, function_->entry().get(),
  758. context_->GetPostDominatorAnalysis(function_)->GetDomTree());
  759. }
  760. };
  761. } // namespace
  762. Pass::Status LoopUnswitchPass::Process() {
  763. bool modified = false;
  764. Module* module = context()->module();
  765. // Process each function in the module
  766. for (Function& f : *module) {
  767. modified |= ProcessFunction(&f);
  768. }
  769. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  770. }
  771. bool LoopUnswitchPass::ProcessFunction(Function* f) {
  772. bool modified = false;
  773. std::unordered_set<Loop*> processed_loop;
  774. LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f);
  775. bool loop_changed = true;
  776. while (loop_changed) {
  777. loop_changed = false;
  778. for (Loop& loop :
  779. make_range(++TreeDFIterator<Loop>(loop_descriptor.GetDummyRootLoop()),
  780. TreeDFIterator<Loop>())) {
  781. if (processed_loop.count(&loop)) continue;
  782. processed_loop.insert(&loop);
  783. LoopUnswitch unswitcher(context(), f, &loop, &loop_descriptor);
  784. while (!unswitcher.WasLoopKilled() && unswitcher.CanUnswitchLoop()) {
  785. if (!loop.IsLCSSA()) {
  786. LoopUtils(context(), &loop).MakeLoopClosedSSA();
  787. }
  788. modified = true;
  789. loop_changed = true;
  790. unswitcher.PerformUnswitch();
  791. }
  792. if (loop_changed) break;
  793. }
  794. }
  795. return modified;
  796. }
  797. } // namespace opt
  798. } // namespace spvtools