if_conversion.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/if_conversion.h"
  15. #include <memory>
  16. #include <vector>
  17. #include "source/opt/value_number_table.h"
  18. namespace spvtools {
  19. namespace opt {
  20. Pass::Status IfConversion::Process() {
  21. if (!context()->get_feature_mgr()->HasCapability(spv::Capability::Shader)) {
  22. return Status::SuccessWithoutChange;
  23. }
  24. const ValueNumberTable& vn_table = *context()->GetValueNumberTable();
  25. bool modified = false;
  26. std::vector<Instruction*> to_kill;
  27. for (auto& func : *get_module()) {
  28. DominatorAnalysis* dominators = context()->GetDominatorAnalysis(&func);
  29. for (auto& block : func) {
  30. // Check if it is possible for |block| to have phis that can be
  31. // transformed.
  32. BasicBlock* common = nullptr;
  33. if (!CheckBlock(&block, dominators, &common)) continue;
  34. // Get an insertion point.
  35. auto iter = block.begin();
  36. while (iter != block.end() && iter->opcode() == spv::Op::OpPhi) {
  37. ++iter;
  38. }
  39. InstructionBuilder builder(
  40. context(), &*iter,
  41. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  42. block.ForEachPhiInst([this, &builder, &modified, &common, &to_kill,
  43. dominators, &block, &vn_table](Instruction* phi) {
  44. // This phi is not compatible, but subsequent phis might be.
  45. if (!CheckType(phi->type_id())) return;
  46. // We cannot transform cases where the phi is used by another phi in the
  47. // same block due to instruction ordering restrictions.
  48. // TODO(alan-baker): If all inappropriate uses could also be
  49. // transformed, we could still remove this phi.
  50. if (!CheckPhiUsers(phi, &block)) return;
  51. // Identify the incoming values associated with the true and false
  52. // branches. If |then_block| dominates |inc0| or if the true edge
  53. // branches straight to this block and |common| is |inc0|, then |inc0|
  54. // is on the true branch. Otherwise the |inc1| is on the true branch.
  55. BasicBlock* inc0 = GetIncomingBlock(phi, 0u);
  56. Instruction* branch = common->terminator();
  57. uint32_t condition = branch->GetSingleWordInOperand(0u);
  58. BasicBlock* then_block = GetBlock(branch->GetSingleWordInOperand(1u));
  59. Instruction* true_value = nullptr;
  60. Instruction* false_value = nullptr;
  61. if ((then_block == &block && inc0 == common) ||
  62. dominators->Dominates(then_block, inc0)) {
  63. true_value = GetIncomingValue(phi, 0u);
  64. false_value = GetIncomingValue(phi, 1u);
  65. } else {
  66. true_value = GetIncomingValue(phi, 1u);
  67. false_value = GetIncomingValue(phi, 0u);
  68. }
  69. BasicBlock* true_def_block = context()->get_instr_block(true_value);
  70. BasicBlock* false_def_block = context()->get_instr_block(false_value);
  71. uint32_t true_vn = vn_table.GetValueNumber(true_value);
  72. uint32_t false_vn = vn_table.GetValueNumber(false_value);
  73. if (true_vn != 0 && true_vn == false_vn) {
  74. Instruction* inst_to_use = nullptr;
  75. // Try to pick an instruction that is not in a side node. If we can't
  76. // pick either the true for false branch as long as they can be
  77. // legally moved.
  78. if (!true_def_block ||
  79. dominators->Dominates(true_def_block, &block)) {
  80. inst_to_use = true_value;
  81. } else if (!false_def_block ||
  82. dominators->Dominates(false_def_block, &block)) {
  83. inst_to_use = false_value;
  84. } else if (CanHoistInstruction(true_value, common, dominators)) {
  85. inst_to_use = true_value;
  86. } else if (CanHoistInstruction(false_value, common, dominators)) {
  87. inst_to_use = false_value;
  88. }
  89. if (inst_to_use != nullptr) {
  90. modified = true;
  91. HoistInstruction(inst_to_use, common, dominators);
  92. context()->KillNamesAndDecorates(phi);
  93. context()->ReplaceAllUsesWith(phi->result_id(),
  94. inst_to_use->result_id());
  95. }
  96. return;
  97. }
  98. // If either incoming value is defined in a block that does not dominate
  99. // this phi, then we cannot eliminate the phi with a select.
  100. // TODO(alan-baker): Perform code motion where it makes sense to enable
  101. // the transform in this case.
  102. if (true_def_block && !dominators->Dominates(true_def_block, &block))
  103. return;
  104. if (false_def_block && !dominators->Dominates(false_def_block, &block))
  105. return;
  106. analysis::Type* data_ty =
  107. context()->get_type_mgr()->GetType(true_value->type_id());
  108. if (analysis::Vector* vec_data_ty = data_ty->AsVector()) {
  109. condition = SplatCondition(vec_data_ty, condition, &builder);
  110. }
  111. Instruction* select = builder.AddSelect(phi->type_id(), condition,
  112. true_value->result_id(),
  113. false_value->result_id());
  114. context()->get_def_use_mgr()->AnalyzeInstDefUse(select);
  115. select->UpdateDebugInfoFrom(phi);
  116. context()->ReplaceAllUsesWith(phi->result_id(), select->result_id());
  117. to_kill.push_back(phi);
  118. modified = true;
  119. return;
  120. });
  121. }
  122. }
  123. for (auto inst : to_kill) {
  124. context()->KillInst(inst);
  125. }
  126. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  127. }
  128. bool IfConversion::CheckBlock(BasicBlock* block, DominatorAnalysis* dominators,
  129. BasicBlock** common) {
  130. const std::vector<uint32_t>& preds = cfg()->preds(block->id());
  131. // TODO(alan-baker): Extend to more than two predecessors
  132. if (preds.size() != 2) return false;
  133. BasicBlock* inc0 = context()->get_instr_block(preds[0]);
  134. if (dominators->Dominates(block, inc0)) return false;
  135. BasicBlock* inc1 = context()->get_instr_block(preds[1]);
  136. if (dominators->Dominates(block, inc1)) return false;
  137. if (inc0 == inc1) {
  138. // If the predecessor blocks are the same, then there is only 1 value for
  139. // the OpPhi. Other transformation should be able to simplify that.
  140. return false;
  141. }
  142. // All phis will have the same common dominator, so cache the result
  143. // for this block. If there is no common dominator, then we cannot transform
  144. // any phi in this basic block.
  145. *common = dominators->CommonDominator(inc0, inc1);
  146. if (!*common || cfg()->IsPseudoEntryBlock(*common)) return false;
  147. Instruction* branch = (*common)->terminator();
  148. if (branch->opcode() != spv::Op::OpBranchConditional) return false;
  149. auto merge = (*common)->GetMergeInst();
  150. if (!merge || merge->opcode() != spv::Op::OpSelectionMerge) return false;
  151. if (spv::SelectionControlMask(merge->GetSingleWordInOperand(1)) ==
  152. spv::SelectionControlMask::DontFlatten) {
  153. return false;
  154. }
  155. if ((*common)->MergeBlockIdIfAny() != block->id()) return false;
  156. return true;
  157. }
  158. bool IfConversion::CheckPhiUsers(Instruction* phi, BasicBlock* block) {
  159. return get_def_use_mgr()->WhileEachUser(
  160. phi, [block, this](Instruction* user) {
  161. if (user->opcode() == spv::Op::OpPhi &&
  162. context()->get_instr_block(user) == block)
  163. return false;
  164. return true;
  165. });
  166. }
  167. uint32_t IfConversion::SplatCondition(analysis::Vector* vec_data_ty,
  168. uint32_t cond,
  169. InstructionBuilder* builder) {
  170. // If the data inputs to OpSelect are vectors, the condition for
  171. // OpSelect must be a boolean vector with the same number of
  172. // components. So splat the condition for the branch into a vector
  173. // type.
  174. analysis::Bool bool_ty;
  175. analysis::Vector bool_vec_ty(&bool_ty, vec_data_ty->element_count());
  176. uint32_t bool_vec_id =
  177. context()->get_type_mgr()->GetTypeInstruction(&bool_vec_ty);
  178. std::vector<uint32_t> ids(vec_data_ty->element_count(), cond);
  179. return builder->AddCompositeConstruct(bool_vec_id, ids)->result_id();
  180. }
  181. bool IfConversion::CheckType(uint32_t id) {
  182. Instruction* type = get_def_use_mgr()->GetDef(id);
  183. spv::Op op = type->opcode();
  184. if (spvOpcodeIsScalarType(op) || op == spv::Op::OpTypePointer ||
  185. op == spv::Op::OpTypeVector)
  186. return true;
  187. return false;
  188. }
  189. BasicBlock* IfConversion::GetBlock(uint32_t id) {
  190. return context()->get_instr_block(get_def_use_mgr()->GetDef(id));
  191. }
  192. BasicBlock* IfConversion::GetIncomingBlock(Instruction* phi,
  193. uint32_t predecessor) {
  194. uint32_t in_index = 2 * predecessor + 1;
  195. return GetBlock(phi->GetSingleWordInOperand(in_index));
  196. }
  197. Instruction* IfConversion::GetIncomingValue(Instruction* phi,
  198. uint32_t predecessor) {
  199. uint32_t in_index = 2 * predecessor;
  200. return get_def_use_mgr()->GetDef(phi->GetSingleWordInOperand(in_index));
  201. }
  202. void IfConversion::HoistInstruction(Instruction* inst, BasicBlock* target_block,
  203. DominatorAnalysis* dominators) {
  204. BasicBlock* inst_block = context()->get_instr_block(inst);
  205. if (!inst_block) {
  206. // This is in the header, and dominates everything.
  207. return;
  208. }
  209. if (dominators->Dominates(inst_block, target_block)) {
  210. // Already in position. No work to do.
  211. return;
  212. }
  213. assert(inst->IsOpcodeCodeMotionSafe() &&
  214. "Trying to move an instruction that is not safe to move.");
  215. // First hoist all instructions it depends on.
  216. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  217. inst->ForEachInId(
  218. [this, target_block, def_use_mgr, dominators](uint32_t* id) {
  219. Instruction* operand_inst = def_use_mgr->GetDef(*id);
  220. HoistInstruction(operand_inst, target_block, dominators);
  221. });
  222. Instruction* insertion_pos = target_block->terminator();
  223. if ((insertion_pos)->PreviousNode()->opcode() == spv::Op::OpSelectionMerge) {
  224. insertion_pos = insertion_pos->PreviousNode();
  225. }
  226. inst->RemoveFromList();
  227. insertion_pos->InsertBefore(std::unique_ptr<Instruction>(inst));
  228. context()->set_instr_block(inst, target_block);
  229. }
  230. bool IfConversion::CanHoistInstruction(Instruction* inst,
  231. BasicBlock* target_block,
  232. DominatorAnalysis* dominators) {
  233. BasicBlock* inst_block = context()->get_instr_block(inst);
  234. if (!inst_block) {
  235. // This is in the header, and dominates everything.
  236. return true;
  237. }
  238. if (dominators->Dominates(inst_block, target_block)) {
  239. // Already in position. No work to do.
  240. return true;
  241. }
  242. if (!inst->IsOpcodeCodeMotionSafe()) {
  243. return false;
  244. }
  245. // Check all instruction |inst| depends on.
  246. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  247. return inst->WhileEachInId(
  248. [this, target_block, def_use_mgr, dominators](uint32_t* id) {
  249. Instruction* operand_inst = def_use_mgr->GetDef(*id);
  250. return CanHoistInstruction(operand_inst, target_block, dominators);
  251. });
  252. }
  253. } // namespace opt
  254. } // namespace spvtools