// code_sink.cpp
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "code_sink.h"
  15. #include <set>
  16. #include <vector>
  17. #include "source/opt/instruction.h"
  18. #include "source/opt/ir_builder.h"
  19. #include "source/opt/ir_context.h"
  20. #include "source/util/bit_vector.h"
  21. namespace spvtools {
  22. namespace opt {
  23. Pass::Status CodeSinkingPass::Process() {
  24. bool modified = false;
  25. for (Function& function : *get_module()) {
  26. cfg()->ForEachBlockInPostOrder(function.entry().get(),
  27. [&modified, this](BasicBlock* bb) {
  28. if (SinkInstructionsInBB(bb)) {
  29. modified = true;
  30. }
  31. });
  32. }
  33. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  34. }
  35. bool CodeSinkingPass::SinkInstructionsInBB(BasicBlock* bb) {
  36. bool modified = false;
  37. for (auto inst = bb->rbegin(); inst != bb->rend(); ++inst) {
  38. if (SinkInstruction(&*inst)) {
  39. inst = bb->rbegin();
  40. modified = true;
  41. }
  42. }
  43. return modified;
  44. }
  45. bool CodeSinkingPass::SinkInstruction(Instruction* inst) {
  46. if (inst->opcode() != SpvOpLoad && inst->opcode() != SpvOpAccessChain) {
  47. return false;
  48. }
  49. if (ReferencesMutableMemory(inst)) {
  50. return false;
  51. }
  52. if (BasicBlock* target_bb = FindNewBasicBlockFor(inst)) {
  53. Instruction* pos = &*target_bb->begin();
  54. while (pos->opcode() == SpvOpPhi) {
  55. pos = pos->NextNode();
  56. }
  57. inst->InsertBefore(pos);
  58. context()->set_instr_block(inst, target_bb);
  59. return true;
  60. }
  61. return false;
  62. }
// Returns the basic block into which |inst| can safely be sunk, or nullptr if
// no block better than its current one exists.  The search walks forward from
// |inst|'s block through single-predecessor successors and selection
// constructs, stopping as soon as moving further could cause |inst| to run
// more often or to stop dominating one of its uses.
BasicBlock* CodeSinkingPass::FindNewBasicBlockFor(Instruction* inst) {
  assert(inst->result_id() != 0 && "Instruction should have a result.");
  BasicBlock* original_bb = context()->get_instr_block(inst);
  BasicBlock* bb = original_bb;

  // Collect the ids of all blocks containing a use of |inst|.  For a use in
  // an OpPhi, the block that matters is the incoming-edge block (the operand
  // immediately following the value), since the value must be available there.
  std::unordered_set<uint32_t> bbs_with_uses;
  get_def_use_mgr()->ForEachUse(
      inst, [&bbs_with_uses, this](Instruction* use, uint32_t idx) {
        if (use->opcode() != SpvOpPhi) {
          BasicBlock* use_bb = context()->get_instr_block(use);
          if (use_bb) {
            bbs_with_uses.insert(use_bb->id());
          }
        } else {
          bbs_with_uses.insert(use->GetSingleWordOperand(idx + 1));
        }
      });

  while (true) {
    // If |inst| is used in |bb|, then |inst| cannot be moved any further.
    if (bbs_with_uses.count(bb->id())) {
      break;
    }

    // If |bb| has one successor (succ_bb), and |bb| is the only predecessor
    // of succ_bb, then |inst| can be moved to succ_bb.  If succ_bb has more
    // than one predecessor, then moving |inst| into succ_bb could cause it to
    // be executed more often, so the search has to stop.
    if (bb->terminator()->opcode() == SpvOpBranch) {
      uint32_t succ_bb_id = bb->terminator()->GetSingleWordInOperand(0);
      if (cfg()->preds(succ_bb_id).size() == 1) {
        bb = context()->get_instr_block(succ_bb_id);
        continue;
      } else {
        break;
      }
    }

    // The remaining checks need to know the merge node.  If there is no merge
    // instruction or an OpLoopMerge, then it is a break or continue.  We could
    // figure it out, but not worth doing it now.
    Instruction* merge_inst = bb->GetMergeInst();
    if (merge_inst == nullptr || merge_inst->opcode() != SpvOpSelectionMerge) {
      break;
    }

    // Check all of the successors of |bb| to see which lead to a use of
    // |inst| before reaching the merge node.
    bool used_in_multiple_blocks = false;
    uint32_t bb_used_in = 0;
    bb->ForEachSuccessorLabel([this, bb, &bb_used_in, &used_in_multiple_blocks,
                               &bbs_with_uses](uint32_t* succ_bb_id) {
      if (IntersectsPath(*succ_bb_id, bb->MergeBlockIdIfAny(), bbs_with_uses)) {
        if (bb_used_in == 0) {
          bb_used_in = *succ_bb_id;
        } else {
          used_in_multiple_blocks = true;
        }
      }
    });

    // If more than one successor, which is not the merge block, uses |inst|,
    // then we have to leave |inst| in |bb| because none of the successors
    // dominate all uses of |inst|.
    if (used_in_multiple_blocks) {
      break;
    }

    if (bb_used_in == 0) {
      // If |inst| is not used before reaching the merge node, then we can
      // move |inst| to the merge node.
      bb = context()->get_instr_block(bb->MergeBlockIdIfAny());
    } else {
      // If the only successor that leads to a use of |inst| has more than 1
      // predecessor, then moving |inst| could cause it to be executed more
      // often, so we cannot move it.
      if (cfg()->preds(bb_used_in).size() != 1) {
        break;
      }

      // If |inst| is used after the merge block, then |bb_used_in| does not
      // dominate all of the uses.  So we cannot move |inst| any further.
      if (IntersectsPath(bb->MergeBlockIdIfAny(), original_bb->id(),
                         bbs_with_uses)) {
        break;
      }

      // Otherwise, |bb_used_in| dominates all uses, so move |inst| into that
      // block.
      bb = context()->get_instr_block(bb_used_in);
    }
    continue;
  }
  return (bb != original_bb ? bb : nullptr);
}
  149. bool CodeSinkingPass::ReferencesMutableMemory(Instruction* inst) {
  150. if (!inst->IsLoad()) {
  151. return false;
  152. }
  153. Instruction* base_ptr = inst->GetBaseAddress();
  154. if (base_ptr->opcode() != SpvOpVariable) {
  155. return true;
  156. }
  157. if (base_ptr->IsReadOnlyPointer()) {
  158. return false;
  159. }
  160. if (HasUniformMemorySync()) {
  161. return true;
  162. }
  163. if (base_ptr->GetSingleWordInOperand(0) != SpvStorageClassUniform) {
  164. return true;
  165. }
  166. return HasPossibleStore(base_ptr);
  167. }
  168. bool CodeSinkingPass::HasUniformMemorySync() {
  169. if (checked_for_uniform_sync_) {
  170. return has_uniform_sync_;
  171. }
  172. bool has_sync = false;
  173. get_module()->ForEachInst([this, &has_sync](Instruction* inst) {
  174. switch (inst->opcode()) {
  175. case SpvOpMemoryBarrier: {
  176. uint32_t mem_semantics_id = inst->GetSingleWordInOperand(1);
  177. if (IsSyncOnUniform(mem_semantics_id)) {
  178. has_sync = true;
  179. }
  180. break;
  181. }
  182. case SpvOpControlBarrier:
  183. case SpvOpAtomicLoad:
  184. case SpvOpAtomicStore:
  185. case SpvOpAtomicExchange:
  186. case SpvOpAtomicIIncrement:
  187. case SpvOpAtomicIDecrement:
  188. case SpvOpAtomicIAdd:
  189. case SpvOpAtomicFAddEXT:
  190. case SpvOpAtomicISub:
  191. case SpvOpAtomicSMin:
  192. case SpvOpAtomicUMin:
  193. case SpvOpAtomicSMax:
  194. case SpvOpAtomicUMax:
  195. case SpvOpAtomicAnd:
  196. case SpvOpAtomicOr:
  197. case SpvOpAtomicXor:
  198. case SpvOpAtomicFlagTestAndSet:
  199. case SpvOpAtomicFlagClear: {
  200. uint32_t mem_semantics_id = inst->GetSingleWordInOperand(2);
  201. if (IsSyncOnUniform(mem_semantics_id)) {
  202. has_sync = true;
  203. }
  204. break;
  205. }
  206. case SpvOpAtomicCompareExchange:
  207. case SpvOpAtomicCompareExchangeWeak:
  208. if (IsSyncOnUniform(inst->GetSingleWordInOperand(2)) ||
  209. IsSyncOnUniform(inst->GetSingleWordInOperand(3))) {
  210. has_sync = true;
  211. }
  212. break;
  213. default:
  214. break;
  215. }
  216. });
  217. has_uniform_sync_ = has_sync;
  218. return has_sync;
  219. }
  220. bool CodeSinkingPass::IsSyncOnUniform(uint32_t mem_semantics_id) const {
  221. const analysis::Constant* mem_semantics_const =
  222. context()->get_constant_mgr()->FindDeclaredConstant(mem_semantics_id);
  223. assert(mem_semantics_const != nullptr &&
  224. "Expecting memory semantics id to be a constant.");
  225. assert(mem_semantics_const->AsIntConstant() &&
  226. "Memory semantics should be an integer.");
  227. uint32_t mem_semantics_int = mem_semantics_const->GetU32();
  228. // If it does not affect uniform memory, then it is does not apply to uniform
  229. // memory.
  230. if ((mem_semantics_int & SpvMemorySemanticsUniformMemoryMask) == 0) {
  231. return false;
  232. }
  233. // Check if there is an acquire or release. If so not, this it does not add
  234. // any memory constraints.
  235. return (mem_semantics_int & (SpvMemorySemanticsAcquireMask |
  236. SpvMemorySemanticsAcquireReleaseMask |
  237. SpvMemorySemanticsReleaseMask)) != 0;
  238. }
  239. bool CodeSinkingPass::HasPossibleStore(Instruction* var_inst) {
  240. assert(var_inst->opcode() == SpvOpVariable ||
  241. var_inst->opcode() == SpvOpAccessChain ||
  242. var_inst->opcode() == SpvOpPtrAccessChain);
  243. return get_def_use_mgr()->WhileEachUser(var_inst, [this](Instruction* use) {
  244. switch (use->opcode()) {
  245. case SpvOpStore:
  246. return true;
  247. case SpvOpAccessChain:
  248. case SpvOpPtrAccessChain:
  249. return HasPossibleStore(use);
  250. default:
  251. return false;
  252. }
  253. });
  254. }
  255. bool CodeSinkingPass::IntersectsPath(uint32_t start, uint32_t end,
  256. const std::unordered_set<uint32_t>& set) {
  257. std::vector<uint32_t> worklist;
  258. worklist.push_back(start);
  259. std::unordered_set<uint32_t> already_done;
  260. already_done.insert(start);
  261. while (!worklist.empty()) {
  262. BasicBlock* bb = context()->get_instr_block(worklist.back());
  263. worklist.pop_back();
  264. if (bb->id() == end) {
  265. continue;
  266. }
  267. if (set.count(bb->id())) {
  268. return true;
  269. }
  270. bb->ForEachSuccessorLabel([&already_done, &worklist](uint32_t* succ_bb_id) {
  271. if (already_done.insert(*succ_bb_id).second) {
  272. worklist.push_back(*succ_bb_id);
  273. }
  274. });
  275. }
  276. return false;
  277. }
}  // namespace opt
  280. } // namespace spvtools