cfg.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/cfg.h"
  15. #include <memory>
  16. #include <utility>
  17. #include "source/cfa.h"
  18. #include "source/opt/ir_builder.h"
  19. #include "source/opt/ir_context.h"
  20. #include "source/opt/module.h"
  21. namespace spvtools {
  22. namespace opt {
namespace {

// Shorthand for the pointer type used by the structured-CFG traversal below.
using cbb_ptr = const opt::BasicBlock*;

// Universal Limit of ResultID + 1.  Used as the label id of the pseudo exit
// block (see the CFG constructor) so it can never collide with a real id.
constexpr int kMaxResultId = 0x400000;

}  // namespace
  28. CFG::CFG(Module* module)
  29. : module_(module),
  30. pseudo_entry_block_(std::unique_ptr<Instruction>(
  31. new Instruction(module->context(), spv::Op::OpLabel, 0, 0, {}))),
  32. pseudo_exit_block_(std::unique_ptr<Instruction>(new Instruction(
  33. module->context(), spv::Op::OpLabel, 0, kMaxResultId, {}))) {
  34. for (auto& fn : *module) {
  35. for (auto& blk : fn) {
  36. RegisterBlock(&blk);
  37. }
  38. }
  39. }
  40. void CFG::AddEdges(BasicBlock* blk) {
  41. uint32_t blk_id = blk->id();
  42. // Force the creation of an entry, not all basic block have predecessors
  43. // (such as the entry blocks and some unreachables).
  44. label2preds_[blk_id];
  45. const auto* const_blk = blk;
  46. const_blk->ForEachSuccessorLabel(
  47. [blk_id, this](const uint32_t succ_id) { AddEdge(blk_id, succ_id); });
  48. }
  49. void CFG::RemoveNonExistingEdges(uint32_t blk_id) {
  50. std::vector<uint32_t> updated_pred_list;
  51. for (uint32_t id : preds(blk_id)) {
  52. const BasicBlock* pred_blk = block(id);
  53. bool has_branch = false;
  54. pred_blk->ForEachSuccessorLabel([&has_branch, blk_id](uint32_t succ) {
  55. if (succ == blk_id) {
  56. has_branch = true;
  57. }
  58. });
  59. if (has_branch) updated_pred_list.push_back(id);
  60. }
  61. label2preds_.at(blk_id) = std::move(updated_pred_list);
  62. }
// Convenience overload: computes the structured order of the whole region
// reachable from |root| by delegating with no terminal (|end| == nullptr)
// block.
void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
                                 std::list<BasicBlock*>* order) {
  ComputeStructuredOrder(func, root, nullptr, order);
}
  67. void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
  68. BasicBlock* end,
  69. std::list<BasicBlock*>* order) {
  70. assert(module_->context()->get_feature_mgr()->HasCapability(
  71. spv::Capability::Shader) &&
  72. "This only works on structured control flow");
  73. // Compute structured successors and do DFS.
  74. ComputeStructuredSuccessors(func);
  75. auto ignore_block = [](cbb_ptr) {};
  76. auto terminal = [end](cbb_ptr bb) { return bb == end; };
  77. auto get_structured_successors = [this](const BasicBlock* b) {
  78. return &(block2structured_succs_[b]);
  79. };
  80. // TODO(greg-lunarg): Get rid of const_cast by making moving const
  81. // out of the cfa.h prototypes and into the invoking code.
  82. auto post_order = [&](cbb_ptr b) {
  83. order->push_front(const_cast<BasicBlock*>(b));
  84. };
  85. CFA<BasicBlock>::DepthFirstTraversal(root, get_structured_successors,
  86. ignore_block, post_order, terminal);
  87. }
  88. void CFG::ForEachBlockInPostOrder(BasicBlock* bb,
  89. const std::function<void(BasicBlock*)>& f) {
  90. std::vector<BasicBlock*> po;
  91. std::unordered_set<BasicBlock*> seen;
  92. ComputePostOrderTraversal(bb, &po, &seen);
  93. for (BasicBlock* current_bb : po) {
  94. if (!IsPseudoExitBlock(current_bb) && !IsPseudoEntryBlock(current_bb)) {
  95. f(current_bb);
  96. }
  97. }
  98. }
  99. void CFG::ForEachBlockInReversePostOrder(
  100. BasicBlock* bb, const std::function<void(BasicBlock*)>& f) {
  101. WhileEachBlockInReversePostOrder(bb, [f](BasicBlock* b) {
  102. f(b);
  103. return true;
  104. });
  105. }
  106. bool CFG::WhileEachBlockInReversePostOrder(
  107. BasicBlock* bb, const std::function<bool(BasicBlock*)>& f) {
  108. std::vector<BasicBlock*> po;
  109. std::unordered_set<BasicBlock*> seen;
  110. ComputePostOrderTraversal(bb, &po, &seen);
  111. for (auto current_bb = po.rbegin(); current_bb != po.rend(); ++current_bb) {
  112. if (!IsPseudoExitBlock(*current_bb) && !IsPseudoEntryBlock(*current_bb)) {
  113. if (!f(*current_bb)) {
  114. return false;
  115. }
  116. }
  117. }
  118. return true;
  119. }
  120. void CFG::ComputeStructuredSuccessors(Function* func) {
  121. block2structured_succs_.clear();
  122. for (auto& blk : *func) {
  123. // If no predecessors in function, make successor to pseudo entry.
  124. if (label2preds_[blk.id()].size() == 0)
  125. block2structured_succs_[&pseudo_entry_block_].push_back(&blk);
  126. // If header, make merge block first successor and continue block second
  127. // successor if there is one.
  128. uint32_t mbid = blk.MergeBlockIdIfAny();
  129. if (mbid != 0) {
  130. block2structured_succs_[&blk].push_back(block(mbid));
  131. uint32_t cbid = blk.ContinueBlockIdIfAny();
  132. if (cbid != 0) {
  133. block2structured_succs_[&blk].push_back(block(cbid));
  134. }
  135. }
  136. // Add true successors.
  137. const auto& const_blk = blk;
  138. const_blk.ForEachSuccessorLabel([&blk, this](const uint32_t sbid) {
  139. block2structured_succs_[&blk].push_back(block(sbid));
  140. });
  141. }
  142. }
// Iterative DFS that appends the blocks reachable from |bb| to |order| in
// post-order.  |seen| records every block that has ever been pushed on the
// explicit stack.
void CFG::ComputePostOrderTraversal(BasicBlock* bb,
                                    std::vector<BasicBlock*>* order,
                                    std::unordered_set<BasicBlock*>* seen) {
  std::vector<BasicBlock*> stack;
  stack.push_back(bb);
  while (!stack.empty()) {
    bb = stack.back();
    seen->insert(bb);
    // Push at most one unvisited successor; WhileEachSuccessorLabel stops
    // as soon as the callback returns false, i.e. right after the push.
    static_cast<const BasicBlock*>(bb)->WhileEachSuccessorLabel(
        [&seen, &stack, this](const uint32_t sbid) {
          BasicBlock* succ_bb = id2block_[sbid];
          if (!seen->count(succ_bb)) {
            stack.push_back(succ_bb);
            return false;  // Descend into the newly pushed block first.
          }
          return true;
        });
    // If nothing was pushed, all successors of |bb| are already seen, so
    // |bb| is finished: emit it in post-order and pop it.
    if (stack.back() == bb) {
      order->push_back(bb);
      stack.pop_back();
    }
  }
}
  166. BasicBlock* CFG::SplitLoopHeader(BasicBlock* bb) {
  167. assert(bb->GetLoopMergeInst() && "Expecting bb to be the header of a loop.");
  168. Function* fn = bb->GetParent();
  169. IRContext* context = module_->context();
  170. // Get the new header id up front. If we are out of ids, then we cannot split
  171. // the loop.
  172. uint32_t new_header_id = context->TakeNextId();
  173. if (new_header_id == 0) {
  174. return nullptr;
  175. }
  176. // Find the insertion point for the new bb.
  177. Function::iterator header_it = std::find_if(
  178. fn->begin(), fn->end(),
  179. [bb](BasicBlock& block_in_func) { return &block_in_func == bb; });
  180. assert(header_it != fn->end());
  181. const std::vector<uint32_t>& pred = preds(bb->id());
  182. // Find the back edge
  183. BasicBlock* latch_block = nullptr;
  184. Function::iterator latch_block_iter = header_it;
  185. for (; latch_block_iter != fn->end(); ++latch_block_iter) {
  186. // If blocks are in the proper order, then the only branch that appears
  187. // after the header is the latch.
  188. if (std::find(pred.begin(), pred.end(), latch_block_iter->id()) !=
  189. pred.end()) {
  190. break;
  191. }
  192. }
  193. assert(latch_block_iter != fn->end() && "Could not find the latch.");
  194. latch_block = &*latch_block_iter;
  195. RemoveSuccessorEdges(bb);
  196. // Create the new header bb basic bb.
  197. // Leave the phi instructions behind.
  198. auto iter = bb->begin();
  199. while (iter->opcode() == spv::Op::OpPhi) {
  200. ++iter;
  201. }
  202. BasicBlock* new_header = bb->SplitBasicBlock(context, new_header_id, iter);
  203. context->AnalyzeDefUse(new_header->GetLabelInst());
  204. // Update cfg
  205. RegisterBlock(new_header);
  206. // Update bb mappings.
  207. context->set_instr_block(new_header->GetLabelInst(), new_header);
  208. new_header->ForEachInst([new_header, context](Instruction* inst) {
  209. context->set_instr_block(inst, new_header);
  210. });
  211. // If |bb| was the latch block, the branch back to the header is not in
  212. // |new_header|.
  213. if (latch_block == bb) {
  214. if (new_header->ContinueBlockId() == bb->id()) {
  215. new_header->GetLoopMergeInst()->SetInOperand(1, {new_header_id});
  216. }
  217. latch_block = new_header;
  218. }
  219. // Adjust the OpPhi instructions as needed.
  220. bb->ForEachPhiInst([latch_block, bb, new_header, context](Instruction* phi) {
  221. std::vector<uint32_t> preheader_phi_ops;
  222. std::vector<Operand> header_phi_ops;
  223. // Identify where the original inputs to original OpPhi belong: header or
  224. // preheader.
  225. for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
  226. uint32_t def_id = phi->GetSingleWordInOperand(i);
  227. uint32_t branch_id = phi->GetSingleWordInOperand(i + 1);
  228. if (branch_id == latch_block->id()) {
  229. header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {def_id}});
  230. header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {branch_id}});
  231. } else {
  232. preheader_phi_ops.push_back(def_id);
  233. preheader_phi_ops.push_back(branch_id);
  234. }
  235. }
  236. // Create a phi instruction if and only if the preheader_phi_ops has more
  237. // than one pair.
  238. if (preheader_phi_ops.size() > 2) {
  239. InstructionBuilder builder(
  240. context, &*bb->begin(),
  241. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  242. Instruction* new_phi = builder.AddPhi(phi->type_id(), preheader_phi_ops);
  243. // Add the OpPhi to the header bb.
  244. header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {new_phi->result_id()}});
  245. header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
  246. } else {
  247. // An OpPhi with a single entry is just a copy. In this case use the same
  248. // instruction in the new header.
  249. header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {preheader_phi_ops[0]}});
  250. header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
  251. }
  252. phi->RemoveFromList();
  253. std::unique_ptr<Instruction> phi_owner(phi);
  254. phi->SetInOperands(std::move(header_phi_ops));
  255. new_header->begin()->InsertBefore(std::move(phi_owner));
  256. context->set_instr_block(phi, new_header);
  257. context->AnalyzeUses(phi);
  258. });
  259. // Add a branch to the new header.
  260. InstructionBuilder branch_builder(
  261. context, bb,
  262. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  263. bb->AddInstruction(
  264. MakeUnique<Instruction>(context, spv::Op::OpBranch, 0, 0,
  265. std::initializer_list<Operand>{
  266. {SPV_OPERAND_TYPE_ID, {new_header->id()}}}));
  267. context->AnalyzeUses(bb->terminator());
  268. context->set_instr_block(bb->terminator(), bb);
  269. label2preds_[new_header->id()].push_back(bb->id());
  270. // Update the latch to branch to the new header.
  271. latch_block->ForEachSuccessorLabel([bb, new_header_id](uint32_t* id) {
  272. if (*id == bb->id()) {
  273. *id = new_header_id;
  274. }
  275. });
  276. Instruction* latch_branch = latch_block->terminator();
  277. context->AnalyzeUses(latch_branch);
  278. label2preds_[new_header->id()].push_back(latch_block->id());
  279. auto& block_preds = label2preds_[bb->id()];
  280. auto latch_pos =
  281. std::find(block_preds.begin(), block_preds.end(), latch_block->id());
  282. assert(latch_pos != block_preds.end() && "The cfg was invalid.");
  283. block_preds.erase(latch_pos);
  284. // Update the loop descriptors
  285. if (context->AreAnalysesValid(IRContext::kAnalysisLoopAnalysis)) {
  286. LoopDescriptor* loop_desc = context->GetLoopDescriptor(bb->GetParent());
  287. Loop* loop = (*loop_desc)[bb->id()];
  288. loop->AddBasicBlock(new_header_id);
  289. loop->SetHeaderBlock(new_header);
  290. loop_desc->SetBasicBlockToLoop(new_header_id, loop);
  291. loop->RemoveBasicBlock(bb->id());
  292. loop->SetPreHeaderBlock(bb);
  293. Loop* parent_loop = loop->GetParent();
  294. if (parent_loop != nullptr) {
  295. parent_loop->AddBasicBlock(bb->id());
  296. loop_desc->SetBasicBlockToLoop(bb->id(), parent_loop);
  297. } else {
  298. loop_desc->SetBasicBlockToLoop(bb->id(), nullptr);
  299. }
  300. }
  301. return new_header;
  302. }
  303. } // namespace opt
  304. } // namespace spvtools