cfg.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/cfg.h"
  15. #include <memory>
  16. #include <utility>
  17. #include "source/cfa.h"
  18. #include "source/opt/ir_builder.h"
  19. #include "source/opt/ir_context.h"
  20. #include "source/opt/module.h"
  21. namespace spvtools {
  22. namespace opt {
  23. namespace {
  24. using cbb_ptr = const opt::BasicBlock*;
  25. // Universal Limit of ResultID + 1
  26. const int kMaxResultId = 0x400000;
  27. } // namespace
  28. CFG::CFG(Module* module)
  29. : module_(module),
  30. pseudo_entry_block_(std::unique_ptr<Instruction>(
  31. new Instruction(module->context(), SpvOpLabel, 0, 0, {}))),
  32. pseudo_exit_block_(std::unique_ptr<Instruction>(new Instruction(
  33. module->context(), SpvOpLabel, 0, kMaxResultId, {}))) {
  34. for (auto& fn : *module) {
  35. for (auto& blk : fn) {
  36. RegisterBlock(&blk);
  37. }
  38. }
  39. }
  40. void CFG::AddEdges(BasicBlock* blk) {
  41. uint32_t blk_id = blk->id();
  42. // Force the creation of an entry, not all basic block have predecessors
  43. // (such as the entry blocks and some unreachables).
  44. label2preds_[blk_id];
  45. const auto* const_blk = blk;
  46. const_blk->ForEachSuccessorLabel(
  47. [blk_id, this](const uint32_t succ_id) { AddEdge(blk_id, succ_id); });
  48. }
  49. void CFG::RemoveNonExistingEdges(uint32_t blk_id) {
  50. std::vector<uint32_t> updated_pred_list;
  51. for (uint32_t id : preds(blk_id)) {
  52. const BasicBlock* pred_blk = block(id);
  53. bool has_branch = false;
  54. pred_blk->ForEachSuccessorLabel([&has_branch, blk_id](uint32_t succ) {
  55. if (succ == blk_id) {
  56. has_branch = true;
  57. }
  58. });
  59. if (has_branch) updated_pred_list.push_back(id);
  60. }
  61. label2preds_.at(blk_id) = std::move(updated_pred_list);
  62. }
  63. void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
  64. std::list<BasicBlock*>* order) {
  65. assert(module_->context()->get_feature_mgr()->HasCapability(
  66. SpvCapabilityShader) &&
  67. "This only works on structured control flow");
  68. // Compute structured successors and do DFS.
  69. ComputeStructuredSuccessors(func);
  70. auto ignore_block = [](cbb_ptr) {};
  71. auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
  72. auto get_structured_successors = [this](const BasicBlock* b) {
  73. return &(block2structured_succs_[b]);
  74. };
  75. // TODO(greg-lunarg): Get rid of const_cast by making moving const
  76. // out of the cfa.h prototypes and into the invoking code.
  77. auto post_order = [&](cbb_ptr b) {
  78. order->push_front(const_cast<BasicBlock*>(b));
  79. };
  80. CFA<BasicBlock>::DepthFirstTraversal(root, get_structured_successors,
  81. ignore_block, post_order, ignore_edge);
  82. }
  83. void CFG::ForEachBlockInPostOrder(BasicBlock* bb,
  84. const std::function<void(BasicBlock*)>& f) {
  85. std::vector<BasicBlock*> po;
  86. std::unordered_set<BasicBlock*> seen;
  87. ComputePostOrderTraversal(bb, &po, &seen);
  88. for (BasicBlock* current_bb : po) {
  89. if (!IsPseudoExitBlock(current_bb) && !IsPseudoEntryBlock(current_bb)) {
  90. f(current_bb);
  91. }
  92. }
  93. }
  94. void CFG::ForEachBlockInReversePostOrder(
  95. BasicBlock* bb, const std::function<void(BasicBlock*)>& f) {
  96. std::vector<BasicBlock*> po;
  97. std::unordered_set<BasicBlock*> seen;
  98. ComputePostOrderTraversal(bb, &po, &seen);
  99. for (auto current_bb = po.rbegin(); current_bb != po.rend(); ++current_bb) {
  100. if (!IsPseudoExitBlock(*current_bb) && !IsPseudoEntryBlock(*current_bb)) {
  101. f(*current_bb);
  102. }
  103. }
  104. }
  105. void CFG::ComputeStructuredSuccessors(Function* func) {
  106. block2structured_succs_.clear();
  107. for (auto& blk : *func) {
  108. // If no predecessors in function, make successor to pseudo entry.
  109. if (label2preds_[blk.id()].size() == 0)
  110. block2structured_succs_[&pseudo_entry_block_].push_back(&blk);
  111. // If header, make merge block first successor and continue block second
  112. // successor if there is one.
  113. uint32_t mbid = blk.MergeBlockIdIfAny();
  114. if (mbid != 0) {
  115. block2structured_succs_[&blk].push_back(block(mbid));
  116. uint32_t cbid = blk.ContinueBlockIdIfAny();
  117. if (cbid != 0) {
  118. block2structured_succs_[&blk].push_back(block(cbid));
  119. }
  120. }
  121. // Add true successors.
  122. const auto& const_blk = blk;
  123. const_blk.ForEachSuccessorLabel([&blk, this](const uint32_t sbid) {
  124. block2structured_succs_[&blk].push_back(block(sbid));
  125. });
  126. }
  127. }
  128. void CFG::ComputePostOrderTraversal(BasicBlock* bb,
  129. std::vector<BasicBlock*>* order,
  130. std::unordered_set<BasicBlock*>* seen) {
  131. seen->insert(bb);
  132. static_cast<const BasicBlock*>(bb)->ForEachSuccessorLabel(
  133. [&order, &seen, this](const uint32_t sbid) {
  134. BasicBlock* succ_bb = id2block_[sbid];
  135. if (!seen->count(succ_bb)) {
  136. ComputePostOrderTraversal(succ_bb, order, seen);
  137. }
  138. });
  139. order->push_back(bb);
  140. }
  141. BasicBlock* CFG::SplitLoopHeader(BasicBlock* bb) {
  142. assert(bb->GetLoopMergeInst() && "Expecting bb to be the header of a loop.");
  143. Function* fn = bb->GetParent();
  144. IRContext* context = module_->context();
  145. // Find the insertion point for the new bb.
  146. Function::iterator header_it = std::find_if(
  147. fn->begin(), fn->end(),
  148. [bb](BasicBlock& block_in_func) { return &block_in_func == bb; });
  149. assert(header_it != fn->end());
  150. const std::vector<uint32_t>& pred = preds(bb->id());
  151. // Find the back edge
  152. BasicBlock* latch_block = nullptr;
  153. Function::iterator latch_block_iter = header_it;
  154. while (++latch_block_iter != fn->end()) {
  155. // If blocks are in the proper order, then the only branch that appears
  156. // after the header is the latch.
  157. if (std::find(pred.begin(), pred.end(), latch_block_iter->id()) !=
  158. pred.end()) {
  159. break;
  160. }
  161. }
  162. assert(latch_block_iter != fn->end() && "Could not find the latch.");
  163. latch_block = &*latch_block_iter;
  164. RemoveSuccessorEdges(bb);
  165. // Create the new header bb basic bb.
  166. // Leave the phi instructions behind.
  167. auto iter = bb->begin();
  168. while (iter->opcode() == SpvOpPhi) {
  169. ++iter;
  170. }
  171. std::unique_ptr<BasicBlock> newBlock(
  172. bb->SplitBasicBlock(context, context->TakeNextId(), iter));
  173. // Insert the new bb in the correct position
  174. auto insert_pos = header_it;
  175. ++insert_pos;
  176. BasicBlock* new_header = &*insert_pos.InsertBefore(std::move(newBlock));
  177. new_header->SetParent(fn);
  178. uint32_t new_header_id = new_header->id();
  179. context->AnalyzeDefUse(new_header->GetLabelInst());
  180. // Update cfg
  181. RegisterBlock(new_header);
  182. // Update bb mappings.
  183. context->set_instr_block(new_header->GetLabelInst(), new_header);
  184. new_header->ForEachInst([new_header, context](Instruction* inst) {
  185. context->set_instr_block(inst, new_header);
  186. });
  187. // Adjust the OpPhi instructions as needed.
  188. bb->ForEachPhiInst([latch_block, bb, new_header, context](Instruction* phi) {
  189. std::vector<uint32_t> preheader_phi_ops;
  190. std::vector<Operand> header_phi_ops;
  191. // Identify where the original inputs to original OpPhi belong: header or
  192. // preheader.
  193. for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
  194. uint32_t def_id = phi->GetSingleWordInOperand(i);
  195. uint32_t branch_id = phi->GetSingleWordInOperand(i + 1);
  196. if (branch_id == latch_block->id()) {
  197. header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {def_id}});
  198. header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {branch_id}});
  199. } else {
  200. preheader_phi_ops.push_back(def_id);
  201. preheader_phi_ops.push_back(branch_id);
  202. }
  203. }
  204. // Create a phi instruction if and only if the preheader_phi_ops has more
  205. // than one pair.
  206. if (preheader_phi_ops.size() > 2) {
  207. InstructionBuilder builder(
  208. context, &*bb->begin(),
  209. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  210. Instruction* new_phi = builder.AddPhi(phi->type_id(), preheader_phi_ops);
  211. // Add the OpPhi to the header bb.
  212. header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {new_phi->result_id()}});
  213. header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
  214. } else {
  215. // An OpPhi with a single entry is just a copy. In this case use the same
  216. // instruction in the new header.
  217. header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {preheader_phi_ops[0]}});
  218. header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
  219. }
  220. phi->RemoveFromList();
  221. std::unique_ptr<Instruction> phi_owner(phi);
  222. phi->SetInOperands(std::move(header_phi_ops));
  223. new_header->begin()->InsertBefore(std::move(phi_owner));
  224. context->set_instr_block(phi, new_header);
  225. context->AnalyzeUses(phi);
  226. });
  227. // Add a branch to the new header.
  228. InstructionBuilder branch_builder(
  229. context, bb,
  230. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  231. bb->AddInstruction(
  232. MakeUnique<Instruction>(context, SpvOpBranch, 0, 0,
  233. std::initializer_list<Operand>{
  234. {SPV_OPERAND_TYPE_ID, {new_header->id()}}}));
  235. context->AnalyzeUses(bb->terminator());
  236. context->set_instr_block(bb->terminator(), bb);
  237. label2preds_[new_header->id()].push_back(bb->id());
  238. // Update the latch to branch to the new header.
  239. latch_block->ForEachSuccessorLabel([bb, new_header_id](uint32_t* id) {
  240. if (*id == bb->id()) {
  241. *id = new_header_id;
  242. }
  243. });
  244. Instruction* latch_branch = latch_block->terminator();
  245. context->AnalyzeUses(latch_branch);
  246. label2preds_[new_header->id()].push_back(latch_block->id());
  247. auto& block_preds = label2preds_[bb->id()];
  248. auto latch_pos =
  249. std::find(block_preds.begin(), block_preds.end(), latch_block->id());
  250. assert(latch_pos != block_preds.end() && "The cfg was invalid.");
  251. block_preds.erase(latch_pos);
  252. // Update the loop descriptors
  253. if (context->AreAnalysesValid(IRContext::kAnalysisLoopAnalysis)) {
  254. LoopDescriptor* loop_desc = context->GetLoopDescriptor(bb->GetParent());
  255. Loop* loop = (*loop_desc)[bb->id()];
  256. loop->AddBasicBlock(new_header_id);
  257. loop->SetHeaderBlock(new_header);
  258. loop_desc->SetBasicBlockToLoop(new_header_id, loop);
  259. loop->RemoveBasicBlock(bb->id());
  260. loop->SetPreHeaderBlock(bb);
  261. Loop* parent_loop = loop->GetParent();
  262. if (parent_loop != nullptr) {
  263. parent_loop->AddBasicBlock(bb->id());
  264. loop_desc->SetBasicBlockToLoop(bb->id(), parent_loop);
  265. } else {
  266. loop_desc->SetBasicBlockToLoop(bb->id(), nullptr);
  267. }
  268. }
  269. return new_header;
  270. }
  271. } // namespace opt
  272. } // namespace spvtools