// cfg.cpp
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
  14. #include "source/opt/cfg.h"
  15. #include <memory>
  16. #include <utility>
  17. #include "source/cfa.h"
  18. #include "source/opt/ir_builder.h"
  19. #include "source/opt/ir_context.h"
  20. #include "source/opt/module.h"
  21. namespace spvtools {
  22. namespace opt {
  23. namespace {
  24. using cbb_ptr = const opt::BasicBlock*;
  25. // Universal Limit of ResultID + 1
  26. const int kMaxResultId = 0x400000;
  27. } // namespace
  28. CFG::CFG(Module* module)
  29. : module_(module),
  30. pseudo_entry_block_(std::unique_ptr<Instruction>(
  31. new Instruction(module->context(), SpvOpLabel, 0, 0, {}))),
  32. pseudo_exit_block_(std::unique_ptr<Instruction>(new Instruction(
  33. module->context(), SpvOpLabel, 0, kMaxResultId, {}))) {
  34. for (auto& fn : *module) {
  35. for (auto& blk : fn) {
  36. RegisterBlock(&blk);
  37. }
  38. }
  39. }
  40. void CFG::AddEdges(BasicBlock* blk) {
  41. uint32_t blk_id = blk->id();
  42. // Force the creation of an entry, not all basic block have predecessors
  43. // (such as the entry blocks and some unreachables).
  44. label2preds_[blk_id];
  45. const auto* const_blk = blk;
  46. const_blk->ForEachSuccessorLabel(
  47. [blk_id, this](const uint32_t succ_id) { AddEdge(blk_id, succ_id); });
  48. }
  49. void CFG::RemoveNonExistingEdges(uint32_t blk_id) {
  50. std::vector<uint32_t> updated_pred_list;
  51. for (uint32_t id : preds(blk_id)) {
  52. const BasicBlock* pred_blk = block(id);
  53. bool has_branch = false;
  54. pred_blk->ForEachSuccessorLabel([&has_branch, blk_id](uint32_t succ) {
  55. if (succ == blk_id) {
  56. has_branch = true;
  57. }
  58. });
  59. if (has_branch) updated_pred_list.push_back(id);
  60. }
  61. label2preds_.at(blk_id) = std::move(updated_pred_list);
  62. }
  63. void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
  64. std::list<BasicBlock*>* order) {
  65. assert(module_->context()->get_feature_mgr()->HasCapability(
  66. SpvCapabilityShader) &&
  67. "This only works on structured control flow");
  68. // Compute structured successors and do DFS.
  69. ComputeStructuredSuccessors(func);
  70. auto ignore_block = [](cbb_ptr) {};
  71. auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
  72. auto get_structured_successors = [this](const BasicBlock* b) {
  73. return &(block2structured_succs_[b]);
  74. };
  75. // TODO(greg-lunarg): Get rid of const_cast by making moving const
  76. // out of the cfa.h prototypes and into the invoking code.
  77. auto post_order = [&](cbb_ptr b) {
  78. order->push_front(const_cast<BasicBlock*>(b));
  79. };
  80. CFA<BasicBlock>::DepthFirstTraversal(root, get_structured_successors,
  81. ignore_block, post_order, ignore_edge);
  82. }
  83. void CFG::ForEachBlockInPostOrder(BasicBlock* bb,
  84. const std::function<void(BasicBlock*)>& f) {
  85. std::vector<BasicBlock*> po;
  86. std::unordered_set<BasicBlock*> seen;
  87. ComputePostOrderTraversal(bb, &po, &seen);
  88. for (BasicBlock* current_bb : po) {
  89. if (!IsPseudoExitBlock(current_bb) && !IsPseudoEntryBlock(current_bb)) {
  90. f(current_bb);
  91. }
  92. }
  93. }
  94. void CFG::ForEachBlockInReversePostOrder(
  95. BasicBlock* bb, const std::function<void(BasicBlock*)>& f) {
  96. WhileEachBlockInReversePostOrder(bb, [f](BasicBlock* b) {
  97. f(b);
  98. return true;
  99. });
  100. }
  101. bool CFG::WhileEachBlockInReversePostOrder(
  102. BasicBlock* bb, const std::function<bool(BasicBlock*)>& f) {
  103. std::vector<BasicBlock*> po;
  104. std::unordered_set<BasicBlock*> seen;
  105. ComputePostOrderTraversal(bb, &po, &seen);
  106. for (auto current_bb = po.rbegin(); current_bb != po.rend(); ++current_bb) {
  107. if (!IsPseudoExitBlock(*current_bb) && !IsPseudoEntryBlock(*current_bb)) {
  108. if (!f(*current_bb)) {
  109. return false;
  110. }
  111. }
  112. }
  113. return true;
  114. }
// Computes the structured successor relation for |func| and stores it in
// |block2structured_succs_|.
//
// The ordering of successors matters: for a header block the merge block is
// appended first (and the continue block second, when present), followed by
// the branch successors.  A DFS over this relation therefore reaches merge
// and continue blocks before the interior of the construct, which is what
// ComputeStructuredOrder relies on.
void CFG::ComputeStructuredSuccessors(Function* func) {
  block2structured_succs_.clear();
  for (auto& blk : *func) {
    // If no predecessors in function, make successor to pseudo entry.
    if (label2preds_[blk.id()].size() == 0)
      block2structured_succs_[&pseudo_entry_block_].push_back(&blk);

    // If header, make merge block first successor and continue block second
    // successor if there is one.
    uint32_t mbid = blk.MergeBlockIdIfAny();
    if (mbid != 0) {
      block2structured_succs_[&blk].push_back(block(mbid));
      uint32_t cbid = blk.ContinueBlockIdIfAny();
      if (cbid != 0) {
        block2structured_succs_[&blk].push_back(block(cbid));
      }
    }

    // Add true successors.
    const auto& const_blk = blk;
    const_blk.ForEachSuccessorLabel([&blk, this](const uint32_t sbid) {
      block2structured_succs_[&blk].push_back(block(sbid));
    });
  }
}
// Appends to |order| a post-order traversal of the blocks reachable from
// |bb|, using an explicit stack instead of recursion.  Every visited block is
// recorded in |seen| (which may be pre-populated to exclude blocks).
void CFG::ComputePostOrderTraversal(BasicBlock* bb,
                                    std::vector<BasicBlock*>* order,
                                    std::unordered_set<BasicBlock*>* seen) {
  std::vector<BasicBlock*> stack;
  stack.push_back(bb);
  while (!stack.empty()) {
    bb = stack.back();
    seen->insert(bb);
    // Push the first not-yet-seen successor and stop scanning (the lambda
    // returns false to halt WhileEachSuccessorLabel).  Remaining successors
    // are picked up when |bb| is re-examined on later iterations.
    static_cast<const BasicBlock*>(bb)->WhileEachSuccessorLabel(
        [&seen, &stack, this](const uint32_t sbid) {
          BasicBlock* succ_bb = id2block_[sbid];
          if (!seen->count(succ_bb)) {
            stack.push_back(succ_bb);
            return false;
          }
          return true;
        });
    // If nothing was pushed, every successor of |bb| has been seen: |bb| is
    // finished, so emit it in post-order and pop it off the stack.
    if (stack.back() == bb) {
      order->push_back(bb);
      stack.pop_back();
    }
  }
}
// Splits the loop header |bb| into two blocks: |bb| becomes the loop's
// preheader (keeping the OpPhi instructions, suitably rewritten) and a newly
// created block becomes the header (receiving everything from the first
// non-phi instruction onward).  Updates the CFG, the def-use and
// instruction-to-block analyses, and (if valid) the loop descriptors.
//
// Returns the new header block, or nullptr if no fresh result id was
// available.
BasicBlock* CFG::SplitLoopHeader(BasicBlock* bb) {
  assert(bb->GetLoopMergeInst() && "Expecting bb to be the header of a loop.");

  Function* fn = bb->GetParent();
  IRContext* context = module_->context();

  // Get the new header id up front.  If we are out of ids, then we cannot
  // split the loop.
  uint32_t new_header_id = context->TakeNextId();
  if (new_header_id == 0) {
    return nullptr;
  }

  // Find the insertion point for the new bb.
  Function::iterator header_it = std::find_if(
      fn->begin(), fn->end(),
      [bb](BasicBlock& block_in_func) { return &block_in_func == bb; });
  assert(header_it != fn->end());

  const std::vector<uint32_t>& pred = preds(bb->id());
  // Find the back edge: scan the blocks after the header for the first one
  // that is also a predecessor of the header.
  BasicBlock* latch_block = nullptr;
  Function::iterator latch_block_iter = header_it;
  while (++latch_block_iter != fn->end()) {
    // If blocks are in the proper order, then the only branch that appears
    // after the header is the latch.
    if (std::find(pred.begin(), pred.end(), latch_block_iter->id()) !=
        pred.end()) {
      break;
    }
  }
  assert(latch_block_iter != fn->end() && "Could not find the latch.");
  latch_block = &*latch_block_iter;

  // Drop |bb|'s successor edges; they will belong to the new header.
  RemoveSuccessorEdges(bb);

  // Create the new header bb basic bb.
  // Leave the phi instructions behind.
  auto iter = bb->begin();
  while (iter->opcode() == SpvOpPhi) {
    ++iter;
  }

  BasicBlock* new_header = bb->SplitBasicBlock(context, new_header_id, iter);
  context->AnalyzeDefUse(new_header->GetLabelInst());

  // Update cfg
  RegisterBlock(new_header);

  // Update bb mappings.
  context->set_instr_block(new_header->GetLabelInst(), new_header);
  new_header->ForEachInst([new_header, context](Instruction* inst) {
    context->set_instr_block(inst, new_header);
  });

  // Adjust the OpPhi instructions as needed: incoming values from the latch
  // stay with the (new) header; all other incoming values move to a phi in
  // the preheader.
  bb->ForEachPhiInst([latch_block, bb, new_header, context](Instruction* phi) {
    std::vector<uint32_t> preheader_phi_ops;
    std::vector<Operand> header_phi_ops;

    // Identify where the original inputs to original OpPhi belong: header or
    // preheader.  Phi operands come in (value id, predecessor id) pairs.
    for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
      uint32_t def_id = phi->GetSingleWordInOperand(i);
      uint32_t branch_id = phi->GetSingleWordInOperand(i + 1);
      if (branch_id == latch_block->id()) {
        header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {def_id}});
        header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {branch_id}});
      } else {
        preheader_phi_ops.push_back(def_id);
        preheader_phi_ops.push_back(branch_id);
      }
    }

    // Create a phi instruction if and only if the preheader_phi_ops has more
    // than one pair.
    if (preheader_phi_ops.size() > 2) {
      InstructionBuilder builder(
          context, &*bb->begin(),
          IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
      Instruction* new_phi = builder.AddPhi(phi->type_id(), preheader_phi_ops);

      // Add the OpPhi to the header bb.
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {new_phi->result_id()}});
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
    } else {
      // An OpPhi with a single entry is just a copy.  In this case use the
      // same instruction in the new header.
      // NOTE(review): assumes preheader_phi_ops is non-empty, i.e. the phi
      // has at least one non-latch incoming edge — TODO confirm this is
      // guaranteed for valid loop headers.
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {preheader_phi_ops[0]}});
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
    }

    // Move the original phi into the new header with its rewritten operands.
    phi->RemoveFromList();
    std::unique_ptr<Instruction> phi_owner(phi);
    phi->SetInOperands(std::move(header_phi_ops));
    new_header->begin()->InsertBefore(std::move(phi_owner));
    context->set_instr_block(phi, new_header);
    context->AnalyzeUses(phi);
  });

  // Add a branch to the new header.
  // NOTE(review): |branch_builder| is never used below — candidate for
  // removal.
  InstructionBuilder branch_builder(
      context, bb,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  bb->AddInstruction(
      MakeUnique<Instruction>(context, SpvOpBranch, 0, 0,
                              std::initializer_list<Operand>{
                                  {SPV_OPERAND_TYPE_ID, {new_header->id()}}}));
  context->AnalyzeUses(bb->terminator());
  context->set_instr_block(bb->terminator(), bb);
  label2preds_[new_header->id()].push_back(bb->id());

  // Update the latch to branch to the new header.
  latch_block->ForEachSuccessorLabel([bb, new_header_id](uint32_t* id) {
    if (*id == bb->id()) {
      *id = new_header_id;
    }
  });
  Instruction* latch_branch = latch_block->terminator();
  context->AnalyzeUses(latch_branch);
  label2preds_[new_header->id()].push_back(latch_block->id());

  // The latch no longer branches to |bb|; remove it from |bb|'s predecessors.
  auto& block_preds = label2preds_[bb->id()];
  auto latch_pos =
      std::find(block_preds.begin(), block_preds.end(), latch_block->id());
  assert(latch_pos != block_preds.end() && "The cfg was invalid.");
  block_preds.erase(latch_pos);

  // Update the loop descriptors: the new block is the loop's header, and |bb|
  // becomes the preheader (which belongs to the parent loop, if any).
  if (context->AreAnalysesValid(IRContext::kAnalysisLoopAnalysis)) {
    LoopDescriptor* loop_desc = context->GetLoopDescriptor(bb->GetParent());
    Loop* loop = (*loop_desc)[bb->id()];

    loop->AddBasicBlock(new_header_id);
    loop->SetHeaderBlock(new_header);
    loop_desc->SetBasicBlockToLoop(new_header_id, loop);

    loop->RemoveBasicBlock(bb->id());
    loop->SetPreHeaderBlock(bb);

    Loop* parent_loop = loop->GetParent();
    if (parent_loop != nullptr) {
      parent_loop->AddBasicBlock(bb->id());
      loop_desc->SetBasicBlockToLoop(bb->id(), parent_loop);
    } else {
      loop_desc->SetBasicBlockToLoop(bb->id(), nullptr);
    }
  }
  return new_header;
}
  290. } // namespace opt
  291. } // namespace spvtools