code_sink.cpp

// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "code_sink.h"

#include <set>
#include <unordered_set>  // std::unordered_set is used directly in this file.
#include <vector>

#include "source/opt/instruction.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
#include "source/util/bit_vector.h"
namespace spvtools {
namespace opt {
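// Sinks eligible instructions in every function of the module.  Each
// function's basic blocks are visited in post order, and instructions that
// can safely be moved closer to their uses are sunk.  Returns
// SuccessWithChange if any instruction was moved.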
Pass::Status CodeSinkingPass::Process() {
  bool modified = false;
  for (Function& function : *get_module()) {
    cfg()->ForEachBlockInPostOrder(function.entry().get(),
                                   [&modified, this](BasicBlock* bb) {
                                     if (SinkInstructionsInBB(bb)) {
                                       modified = true;
                                     }
                                   });
  }
  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
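// Tries to sink each instruction in |bb|, scanning from the end of the block.
// After a successful sink the scan restarts at bb->rbegin(), because the
// iterator that referred to the moved instruction is no longer valid for
// |bb|.  Returns true if at least one instruction was moved.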
bool CodeSinkingPass::SinkInstructionsInBB(BasicBlock* bb) {
  bool modified = false;
  for (auto inst = bb->rbegin(); inst != bb->rend(); ++inst) {
    if (SinkInstruction(&*inst)) {
      inst = bb->rbegin();
      modified = true;
    }
  }
  return modified;
}
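// Tries to move |inst| into a later basic block.  Only OpLoad and
// OpAccessChain instructions that do not reference mutable memory are
// candidates.  If FindNewBasicBlockFor returns a target block, |inst| is
// inserted just after that block's OpPhi instructions and the
// instruction-to-block mapping is updated.  Returns true if |inst| was moved.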
bool CodeSinkingPass::SinkInstruction(Instruction* inst) {
  if (inst->opcode() != SpvOpLoad && inst->opcode() != SpvOpAccessChain) {
    return false;
  }

  if (ReferencesMutableMemory(inst)) {
    return false;
  }

  if (BasicBlock* target_bb = FindNewBasicBlockFor(inst)) {
    Instruction* pos = &*target_bb->begin();
    while (pos->opcode() == SpvOpPhi) {
      pos = pos->NextNode();
    }

    inst->InsertBefore(pos);
    context()->set_instr_block(inst, target_bb);
    return true;
  }
  return false;
}
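// Searches for a basic block that |inst| can be sunk into.  Starting from the
// block that currently contains |inst|, the search walks forward through the
// CFG as long as moving |inst| cannot cause it to be executed more often and
// the candidate block still dominates every use.  Returns the new block, or
// nullptr if |inst| should stay where it is.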
BasicBlock* CodeSinkingPass::FindNewBasicBlockFor(Instruction* inst) {
  assert(inst->result_id() != 0 && "Instruction should have a result.");
  BasicBlock* original_bb = context()->get_instr_block(inst);
  BasicBlock* bb = original_bb;

  std::unordered_set<uint32_t> bbs_with_uses;
  get_def_use_mgr()->ForEachUse(
      inst, [&bbs_with_uses, this](Instruction* use, uint32_t idx) {
        if (use->opcode() != SpvOpPhi) {
          bbs_with_uses.insert(context()->get_instr_block(use)->id());
        } else {
          bbs_with_uses.insert(use->GetSingleWordOperand(idx + 1));
        }
      });

  while (true) {
    // If |inst| is used in |bb|, then |inst| cannot be moved any further.
    if (bbs_with_uses.count(bb->id())) {
      break;
    }

    // If |bb| has one successor (succ_bb), and |bb| is the only predecessor
    // of succ_bb, then |inst| can be moved to succ_bb.  If succ_bb has more
    // than one predecessor, then moving |inst| into succ_bb could cause it to
    // be executed more often, so the search has to stop.
    if (bb->terminator()->opcode() == SpvOpBranch) {
      uint32_t succ_bb_id = bb->terminator()->GetSingleWordInOperand(0);
      if (cfg()->preds(succ_bb_id).size() == 1) {
        bb = context()->get_instr_block(succ_bb_id);
        continue;
      } else {
        break;
      }
    }

    // The remaining checks need to know the merge node.  If there is no merge
    // instruction or an OpLoopMerge, then it is a break or continue.  We could
    // figure it out, but it is not worth doing now.
    Instruction* merge_inst = bb->GetMergeInst();
    if (merge_inst == nullptr || merge_inst->opcode() != SpvOpSelectionMerge) {
      break;
    }

    // Check all of the successors of |bb| to see which lead to a use of |inst|
    // before reaching the merge node.
    bool used_in_multiple_blocks = false;
    uint32_t bb_used_in = 0;
    bb->ForEachSuccessorLabel([this, bb, &bb_used_in, &used_in_multiple_blocks,
                               &bbs_with_uses](uint32_t* succ_bb_id) {
      if (IntersectsPath(*succ_bb_id, bb->MergeBlockIdIfAny(), bbs_with_uses)) {
        if (bb_used_in == 0) {
          bb_used_in = *succ_bb_id;
        } else {
          used_in_multiple_blocks = true;
        }
      }
    });

    // If more than one successor, other than the merge block, uses |inst|,
    // then we have to leave |inst| in |bb| because none of the successors
    // dominates all uses of |inst|.
    if (used_in_multiple_blocks) {
      break;
    }

    if (bb_used_in == 0) {
      // If |inst| is not used before reaching the merge node, then we can move
      // |inst| to the merge node.
      bb = context()->get_instr_block(bb->MergeBlockIdIfAny());
    } else {
      // If the only successor that leads to a use of |inst| has more than 1
      // predecessor, then moving |inst| could cause it to be executed more
      // often, so we cannot move it.
      if (cfg()->preds(bb_used_in).size() != 1) {
        break;
      }

      // If |inst| is used after the merge block, then |bb_used_in| does not
      // dominate all of the uses.  So we cannot move |inst| any further.
      if (IntersectsPath(bb->MergeBlockIdIfAny(), original_bb->id(),
                         bbs_with_uses)) {
        break;
      }

      // Otherwise, |bb_used_in| dominates all uses, so move |inst| into that
      // block.
      bb = context()->get_instr_block(bb_used_in);
    }
    continue;
  }
  return (bb != original_bb ? bb : nullptr);
}
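// Returns true if |inst| loads from memory that might be written to, in which
// case it is not safe to sink it.  Loads from read-only variables, and loads
// from Uniform-storage variables that have no possible store and no uniform
// memory synchronization in the module, are treated as safe.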
bool CodeSinkingPass::ReferencesMutableMemory(Instruction* inst) {
  if (!inst->IsLoad()) {
    return false;
  }

  Instruction* base_ptr = inst->GetBaseAddress();
  if (base_ptr->opcode() != SpvOpVariable) {
    return true;
  }

  if (base_ptr->IsReadOnlyVariable()) {
    return false;
  }

  if (HasUniformMemorySync()) {
    return true;
  }

  if (base_ptr->GetSingleWordInOperand(0) != SpvStorageClassUniform) {
    return true;
  }

  return HasPossibleStore(base_ptr);
}
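// Returns true if the module contains a memory barrier, control barrier, or
// atomic instruction whose memory semantics synchronize Uniform storage.  The
// result is cached so the module only has to be scanned once.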
bool CodeSinkingPass::HasUniformMemorySync() {
  if (checked_for_uniform_sync_) {
    return has_uniform_sync_;
  }

  bool has_sync = false;
  get_module()->ForEachInst([this, &has_sync](Instruction* inst) {
    switch (inst->opcode()) {
      case SpvOpMemoryBarrier: {
        uint32_t mem_semantics_id = inst->GetSingleWordInOperand(1);
        if (IsSyncOnUniform(mem_semantics_id)) {
          has_sync = true;
        }
        break;
      }
      case SpvOpControlBarrier:
      case SpvOpAtomicLoad:
      case SpvOpAtomicStore:
      case SpvOpAtomicExchange:
      case SpvOpAtomicIIncrement:
      case SpvOpAtomicIDecrement:
      case SpvOpAtomicIAdd:
      case SpvOpAtomicISub:
      case SpvOpAtomicSMin:
      case SpvOpAtomicUMin:
      case SpvOpAtomicSMax:
      case SpvOpAtomicUMax:
      case SpvOpAtomicAnd:
      case SpvOpAtomicOr:
      case SpvOpAtomicXor:
      case SpvOpAtomicFlagTestAndSet:
      case SpvOpAtomicFlagClear: {
        uint32_t mem_semantics_id = inst->GetSingleWordInOperand(2);
        if (IsSyncOnUniform(mem_semantics_id)) {
          has_sync = true;
        }
        break;
      }
      case SpvOpAtomicCompareExchange:
      case SpvOpAtomicCompareExchangeWeak:
        if (IsSyncOnUniform(inst->GetSingleWordInOperand(2)) ||
            IsSyncOnUniform(inst->GetSingleWordInOperand(3))) {
          has_sync = true;
        }
        break;
      default:
        break;
    }
  });
  // Record the result so later calls return the cached value instead of
  // rescanning the module.
  checked_for_uniform_sync_ = true;
  has_uniform_sync_ = has_sync;
  return has_sync;
}
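// Returns true if the memory semantics value defined by |mem_semantics_id|
// applies to Uniform memory and imposes an ordering constraint (Acquire,
// Release, or AcquireRelease).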
bool CodeSinkingPass::IsSyncOnUniform(uint32_t mem_semantics_id) const {
  const analysis::Constant* mem_semantics_const =
      context()->get_constant_mgr()->FindDeclaredConstant(mem_semantics_id);
  assert(mem_semantics_const != nullptr &&
         "Expecting memory semantics id to be a constant.");
  assert(mem_semantics_const->AsIntConstant() &&
         "Memory semantics should be an integer.");
  uint32_t mem_semantics_int = mem_semantics_const->GetU32();

  // If the semantics do not include the UniformMemory bit, then they do not
  // apply to uniform memory.
  if ((mem_semantics_int & SpvMemorySemanticsUniformMemoryMask) == 0) {
    return false;
  }

  // Check if there is an acquire or release.  If not, then the instruction
  // does not add any memory-ordering constraints.
  return (mem_semantics_int & (SpvMemorySemanticsAcquireMask |
                               SpvMemorySemanticsAcquireReleaseMask |
                               SpvMemorySemanticsReleaseMask)) != 0;
}
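// Returns true if there may be a store to the memory reached through
// |var_inst|.  Each user of |var_inst| is inspected, and access chains are
// followed recursively.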
bool CodeSinkingPass::HasPossibleStore(Instruction* var_inst) {
  assert(var_inst->opcode() == SpvOpVariable ||
         var_inst->opcode() == SpvOpAccessChain ||
         var_inst->opcode() == SpvOpPtrAccessChain);

  return get_def_use_mgr()->WhileEachUser(var_inst, [this](Instruction* use) {
    switch (use->opcode()) {
      case SpvOpStore:
        return true;
      case SpvOpAccessChain:
      case SpvOpPtrAccessChain:
        return HasPossibleStore(use);
      default:
        return false;
    }
  });
}
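// Returns true if some block reachable from |start| without passing through
// |end| is contained in |set|.  This is a simple worklist traversal over
// successor edges.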
bool CodeSinkingPass::IntersectsPath(uint32_t start, uint32_t end,
                                     const std::unordered_set<uint32_t>& set) {
  std::vector<uint32_t> worklist;
  worklist.push_back(start);
  std::unordered_set<uint32_t> already_done;
  already_done.insert(start);

  while (!worklist.empty()) {
    BasicBlock* bb = context()->get_instr_block(worklist.back());
    worklist.pop_back();

    if (bb->id() == end) {
      continue;
    }

    if (set.count(bb->id())) {
      return true;
    }

    bb->ForEachSuccessorLabel([&already_done, &worklist](uint32_t* succ_bb_id) {
      if (already_done.insert(*succ_bb_id).second) {
        worklist.push_back(*succ_bb_id);
      }
    });
  }
  return false;
}
}  // namespace opt
}  // namespace spvtools