// cfg.cpp
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
  14. #include "source/opt/cfg.h"
  15. #include <memory>
  16. #include <utility>
  17. #include "source/cfa.h"
  18. #include "source/opt/ir_builder.h"
  19. #include "source/opt/ir_context.h"
  20. #include "source/opt/module.h"
  21. namespace spvtools {
  22. namespace opt {
namespace {

// Shorthand for a pointer to an immutable basic block, used by the CFA
// traversal callbacks below.
using cbb_ptr = const opt::BasicBlock*;

// Universal Limit of ResultID + 1. Used as the label id of the pseudo exit
// block so it can never collide with a real block id.
const int kMaxResultId = 0x400000;

}  // namespace
  28. CFG::CFG(Module* module)
  29. : module_(module),
  30. pseudo_entry_block_(std::unique_ptr<Instruction>(
  31. new Instruction(module->context(), SpvOpLabel, 0, 0, {}))),
  32. pseudo_exit_block_(std::unique_ptr<Instruction>(new Instruction(
  33. module->context(), SpvOpLabel, 0, kMaxResultId, {}))) {
  34. for (auto& fn : *module) {
  35. for (auto& blk : fn) {
  36. RegisterBlock(&blk);
  37. }
  38. }
  39. }
  40. void CFG::AddEdges(BasicBlock* blk) {
  41. uint32_t blk_id = blk->id();
  42. // Force the creation of an entry, not all basic block have predecessors
  43. // (such as the entry blocks and some unreachables).
  44. label2preds_[blk_id];
  45. const auto* const_blk = blk;
  46. const_blk->ForEachSuccessorLabel(
  47. [blk_id, this](const uint32_t succ_id) { AddEdge(blk_id, succ_id); });
  48. }
  49. void CFG::RemoveNonExistingEdges(uint32_t blk_id) {
  50. std::vector<uint32_t> updated_pred_list;
  51. for (uint32_t id : preds(blk_id)) {
  52. const BasicBlock* pred_blk = block(id);
  53. bool has_branch = false;
  54. pred_blk->ForEachSuccessorLabel([&has_branch, blk_id](uint32_t succ) {
  55. if (succ == blk_id) {
  56. has_branch = true;
  57. }
  58. });
  59. if (has_branch) updated_pred_list.push_back(id);
  60. }
  61. label2preds_.at(blk_id) = std::move(updated_pred_list);
  62. }
  63. void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
  64. std::list<BasicBlock*>* order) {
  65. ComputeStructuredOrder(func, root, nullptr, order);
  66. }
  67. void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
  68. BasicBlock* end,
  69. std::list<BasicBlock*>* order) {
  70. assert(module_->context()->get_feature_mgr()->HasCapability(
  71. SpvCapabilityShader) &&
  72. "This only works on structured control flow");
  73. // Compute structured successors and do DFS.
  74. ComputeStructuredSuccessors(func);
  75. auto ignore_block = [](cbb_ptr) {};
  76. auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
  77. auto terminal = [end](cbb_ptr bb) { return bb == end; };
  78. auto get_structured_successors = [this](const BasicBlock* b) {
  79. return &(block2structured_succs_[b]);
  80. };
  81. // TODO(greg-lunarg): Get rid of const_cast by making moving const
  82. // out of the cfa.h prototypes and into the invoking code.
  83. auto post_order = [&](cbb_ptr b) {
  84. order->push_front(const_cast<BasicBlock*>(b));
  85. };
  86. CFA<BasicBlock>::DepthFirstTraversal(root, get_structured_successors,
  87. ignore_block, post_order, ignore_edge,
  88. terminal);
  89. }
  90. void CFG::ForEachBlockInPostOrder(BasicBlock* bb,
  91. const std::function<void(BasicBlock*)>& f) {
  92. std::vector<BasicBlock*> po;
  93. std::unordered_set<BasicBlock*> seen;
  94. ComputePostOrderTraversal(bb, &po, &seen);
  95. for (BasicBlock* current_bb : po) {
  96. if (!IsPseudoExitBlock(current_bb) && !IsPseudoEntryBlock(current_bb)) {
  97. f(current_bb);
  98. }
  99. }
  100. }
  101. void CFG::ForEachBlockInReversePostOrder(
  102. BasicBlock* bb, const std::function<void(BasicBlock*)>& f) {
  103. WhileEachBlockInReversePostOrder(bb, [f](BasicBlock* b) {
  104. f(b);
  105. return true;
  106. });
  107. }
  108. bool CFG::WhileEachBlockInReversePostOrder(
  109. BasicBlock* bb, const std::function<bool(BasicBlock*)>& f) {
  110. std::vector<BasicBlock*> po;
  111. std::unordered_set<BasicBlock*> seen;
  112. ComputePostOrderTraversal(bb, &po, &seen);
  113. for (auto current_bb = po.rbegin(); current_bb != po.rend(); ++current_bb) {
  114. if (!IsPseudoExitBlock(*current_bb) && !IsPseudoEntryBlock(*current_bb)) {
  115. if (!f(*current_bb)) {
  116. return false;
  117. }
  118. }
  119. }
  120. return true;
  121. }
  122. void CFG::ComputeStructuredSuccessors(Function* func) {
  123. block2structured_succs_.clear();
  124. for (auto& blk : *func) {
  125. // If no predecessors in function, make successor to pseudo entry.
  126. if (label2preds_[blk.id()].size() == 0)
  127. block2structured_succs_[&pseudo_entry_block_].push_back(&blk);
  128. // If header, make merge block first successor and continue block second
  129. // successor if there is one.
  130. uint32_t mbid = blk.MergeBlockIdIfAny();
  131. if (mbid != 0) {
  132. block2structured_succs_[&blk].push_back(block(mbid));
  133. uint32_t cbid = blk.ContinueBlockIdIfAny();
  134. if (cbid != 0) {
  135. block2structured_succs_[&blk].push_back(block(cbid));
  136. }
  137. }
  138. // Add true successors.
  139. const auto& const_blk = blk;
  140. const_blk.ForEachSuccessorLabel([&blk, this](const uint32_t sbid) {
  141. block2structured_succs_[&blk].push_back(block(sbid));
  142. });
  143. }
  144. }
// Computes a post-order traversal of the CFG rooted at |bb| using an explicit
// stack, appending blocks to |order|. |seen| records blocks already pushed so
// they are never revisited; it may be pre-populated by the caller to exclude
// blocks.
void CFG::ComputePostOrderTraversal(BasicBlock* bb,
                                    std::vector<BasicBlock*>* order,
                                    std::unordered_set<BasicBlock*>* seen) {
  std::vector<BasicBlock*> stack;
  stack.push_back(bb);
  while (!stack.empty()) {
    bb = stack.back();
    seen->insert(bb);
    // Push the first not-yet-seen successor and stop iterating (return false)
    // so the DFS descends into it before finishing |bb|.
    static_cast<const BasicBlock*>(bb)->WhileEachSuccessorLabel(
        [&seen, &stack, this](const uint32_t sbid) {
          BasicBlock* succ_bb = id2block_[sbid];
          if (!seen->count(succ_bb)) {
            stack.push_back(succ_bb);
            return false;
          }
          return true;
        });
    // If nothing was pushed, every successor of |bb| is done: emit |bb| in
    // post order and pop it.
    if (stack.back() == bb) {
      order->push_back(bb);
      stack.pop_back();
    }
  }
}
// Splits the loop header |bb| into two blocks: |bb| keeps the OpPhi
// instructions and becomes the loop preheader, while a newly created block
// (returned) receives the remaining instructions and becomes the new header.
// Updates the CFG edges, the phi operands, the instruction-to-block mappings,
// and (when valid) the loop descriptor. Returns nullptr if no fresh result id
// is available.
BasicBlock* CFG::SplitLoopHeader(BasicBlock* bb) {
  assert(bb->GetLoopMergeInst() && "Expecting bb to be the header of a loop.");

  Function* fn = bb->GetParent();
  IRContext* context = module_->context();

  // Get the new header id up front. If we are out of ids, then we cannot split
  // the loop.
  uint32_t new_header_id = context->TakeNextId();
  if (new_header_id == 0) {
    return nullptr;
  }

  // Find the insertion point for the new bb.
  Function::iterator header_it = std::find_if(
      fn->begin(), fn->end(),
      [bb](BasicBlock& block_in_func) { return &block_in_func == bb; });
  assert(header_it != fn->end());

  const std::vector<uint32_t>& pred = preds(bb->id());
  // Find the back edge: scan forward from the header for the first block that
  // is also a predecessor of the header.
  BasicBlock* latch_block = nullptr;
  Function::iterator latch_block_iter = header_it;
  for (; latch_block_iter != fn->end(); ++latch_block_iter) {
    // If blocks are in the proper order, then the only branch that appears
    // after the header is the latch.
    if (std::find(pred.begin(), pred.end(), latch_block_iter->id()) !=
        pred.end()) {
      break;
    }
  }
  assert(latch_block_iter != fn->end() && "Could not find the latch.");
  latch_block = &*latch_block_iter;

  RemoveSuccessorEdges(bb);

  // Create the new header bb basic bb.
  // Leave the phi instructions behind.
  auto iter = bb->begin();
  while (iter->opcode() == SpvOpPhi) {
    ++iter;
  }

  BasicBlock* new_header = bb->SplitBasicBlock(context, new_header_id, iter);
  context->AnalyzeDefUse(new_header->GetLabelInst());

  // Update cfg
  RegisterBlock(new_header);

  // Update bb mappings.
  context->set_instr_block(new_header->GetLabelInst(), new_header);
  new_header->ForEachInst([new_header, context](Instruction* inst) {
    context->set_instr_block(inst, new_header);
  });

  // If |bb| was the latch block, the branch back to the header is not in
  // |new_header|.
  if (latch_block == bb) {
    // Retarget the continue block operand of the OpLoopMerge when it pointed
    // at the old header.
    if (new_header->ContinueBlockId() == bb->id()) {
      new_header->GetLoopMergeInst()->SetInOperand(1, {new_header_id});
    }
    latch_block = new_header;
  }

  // Adjust the OpPhi instructions as needed.
  bb->ForEachPhiInst([latch_block, bb, new_header, context](Instruction* phi) {
    std::vector<uint32_t> preheader_phi_ops;
    std::vector<Operand> header_phi_ops;

    // Identify where the original inputs to original OpPhi belong: header or
    // preheader. Operands come in (value, predecessor-label) pairs.
    for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
      uint32_t def_id = phi->GetSingleWordInOperand(i);
      uint32_t branch_id = phi->GetSingleWordInOperand(i + 1);
      if (branch_id == latch_block->id()) {
        header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {def_id}});
        header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {branch_id}});
      } else {
        preheader_phi_ops.push_back(def_id);
        preheader_phi_ops.push_back(branch_id);
      }
    }

    // Create a phi instruction if and only if the preheader_phi_ops has more
    // than one pair.
    if (preheader_phi_ops.size() > 2) {
      InstructionBuilder builder(
          context, &*bb->begin(),
          IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);

      Instruction* new_phi = builder.AddPhi(phi->type_id(), preheader_phi_ops);

      // Add the OpPhi to the header bb.
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {new_phi->result_id()}});
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
    } else {
      // An OpPhi with a single entry is just a copy. In this case use the same
      // instruction in the new header.
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {preheader_phi_ops[0]}});
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
    }

    // Move the phi into the new header with its rewritten operands.
    phi->RemoveFromList();
    std::unique_ptr<Instruction> phi_owner(phi);
    phi->SetInOperands(std::move(header_phi_ops));
    new_header->begin()->InsertBefore(std::move(phi_owner));
    context->set_instr_block(phi, new_header);
    context->AnalyzeUses(phi);
  });

  // Add a branch to the new header.
  InstructionBuilder branch_builder(
      context, bb,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  bb->AddInstruction(
      MakeUnique<Instruction>(context, SpvOpBranch, 0, 0,
                              std::initializer_list<Operand>{
                                  {SPV_OPERAND_TYPE_ID, {new_header->id()}}}));
  context->AnalyzeUses(bb->terminator());
  context->set_instr_block(bb->terminator(), bb);
  label2preds_[new_header->id()].push_back(bb->id());

  // Update the latch to branch to the new header.
  latch_block->ForEachSuccessorLabel([bb, new_header_id](uint32_t* id) {
    if (*id == bb->id()) {
      *id = new_header_id;
    }
  });
  Instruction* latch_branch = latch_block->terminator();
  context->AnalyzeUses(latch_branch);
  label2preds_[new_header->id()].push_back(latch_block->id());

  // The latch no longer branches to |bb|; drop it from |bb|'s predecessors.
  auto& block_preds = label2preds_[bb->id()];
  auto latch_pos =
      std::find(block_preds.begin(), block_preds.end(), latch_block->id());
  assert(latch_pos != block_preds.end() && "The cfg was invalid.");
  block_preds.erase(latch_pos);

  // Update the loop descriptors
  if (context->AreAnalysesValid(IRContext::kAnalysisLoopAnalysis)) {
    LoopDescriptor* loop_desc = context->GetLoopDescriptor(bb->GetParent());
    Loop* loop = (*loop_desc)[bb->id()];

    // The new block becomes the loop header; |bb| becomes the preheader and
    // moves to the parent loop (or out of any loop).
    loop->AddBasicBlock(new_header_id);
    loop->SetHeaderBlock(new_header);
    loop_desc->SetBasicBlockToLoop(new_header_id, loop);

    loop->RemoveBasicBlock(bb->id());
    loop->SetPreHeaderBlock(bb);

    Loop* parent_loop = loop->GetParent();
    if (parent_loop != nullptr) {
      parent_loop->AddBasicBlock(bb->id());
      loop_desc->SetBasicBlockToLoop(bb->id(), parent_loop);
    } else {
      loop_desc->SetBasicBlockToLoop(bb->id(), nullptr);
    }
  }
  return new_header;
}
  305. } // namespace opt
  306. } // namespace spvtools