// code_sink.cpp
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "code_sink.h"
  15. #include <set>
  16. #include <vector>
  17. #include "source/opt/instruction.h"
  18. #include "source/opt/ir_builder.h"
  19. #include "source/opt/ir_context.h"
  20. #include "source/util/bit_vector.h"
  21. namespace spvtools {
  22. namespace opt {
  23. Pass::Status CodeSinkingPass::Process() {
  24. bool modified = false;
  25. for (Function& function : *get_module()) {
  26. cfg()->ForEachBlockInPostOrder(function.entry().get(),
  27. [&modified, this](BasicBlock* bb) {
  28. if (SinkInstructionsInBB(bb)) {
  29. modified = true;
  30. }
  31. });
  32. }
  33. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  34. }
  35. bool CodeSinkingPass::SinkInstructionsInBB(BasicBlock* bb) {
  36. bool modified = false;
  37. for (auto inst = bb->rbegin(); inst != bb->rend(); ++inst) {
  38. if (SinkInstruction(&*inst)) {
  39. inst = bb->rbegin();
  40. modified = true;
  41. }
  42. }
  43. return modified;
  44. }
  45. bool CodeSinkingPass::SinkInstruction(Instruction* inst) {
  46. if (inst->opcode() != SpvOpLoad && inst->opcode() != SpvOpAccessChain) {
  47. return false;
  48. }
  49. if (ReferencesMutableMemory(inst)) {
  50. return false;
  51. }
  52. if (BasicBlock* target_bb = FindNewBasicBlockFor(inst)) {
  53. Instruction* pos = &*target_bb->begin();
  54. while (pos->opcode() == SpvOpPhi) {
  55. pos = pos->NextNode();
  56. }
  57. inst->InsertBefore(pos);
  58. context()->set_instr_block(inst, target_bb);
  59. return true;
  60. }
  61. return false;
  62. }
// Returns the basic block that |inst| should be moved to, or nullptr if no
// better block than its current one was found.  Starting from the block that
// currently holds |inst|, the search walks forward through the CFG looking
// for the furthest block that still dominates every use of |inst| and is not
// executed more often than the original block.
BasicBlock* CodeSinkingPass::FindNewBasicBlockFor(Instruction* inst) {
  assert(inst->result_id() != 0 && "Instruction should have a result.");
  BasicBlock* original_bb = context()->get_instr_block(inst);
  BasicBlock* bb = original_bb;

  // Collect the ids of every block containing a use of |inst|.  For a phi
  // use, the value is consumed on the incoming edge, so record the
  // predecessor block the value flows in from (operand |idx + 1|) instead of
  // the block holding the phi itself.
  std::unordered_set<uint32_t> bbs_with_uses;
  get_def_use_mgr()->ForEachUse(
      inst, [&bbs_with_uses, this](Instruction* use, uint32_t idx) {
        if (use->opcode() != SpvOpPhi) {
          BasicBlock* use_bb = context()->get_instr_block(use);
          if (use_bb) {
            bbs_with_uses.insert(use_bb->id());
          }
        } else {
          bbs_with_uses.insert(use->GetSingleWordOperand(idx + 1));
        }
      });

  while (true) {
    // If |inst| is used in |bb|, then |inst| cannot be moved any further.
    if (bbs_with_uses.count(bb->id())) {
      break;
    }

    // If |bb| has one successor (succ_bb), and |bb| is the only predecessor
    // of succ_bb, then |inst| can be moved to succ_bb.  If succ_bb has more
    // than one predecessor, then moving |inst| into succ_bb could cause it to
    // be executed more often, so the search has to stop.
    if (bb->terminator()->opcode() == SpvOpBranch) {
      uint32_t succ_bb_id = bb->terminator()->GetSingleWordInOperand(0);
      if (cfg()->preds(succ_bb_id).size() == 1) {
        bb = context()->get_instr_block(succ_bb_id);
        continue;
      } else {
        break;
      }
    }

    // The remaining checks need to know the merge node.  If there is no merge
    // instruction, or it is an OpLoopMerge, then the branch is a break or
    // continue.  We could figure it out, but it is not worth doing now.
    Instruction* merge_inst = bb->GetMergeInst();
    if (merge_inst == nullptr || merge_inst->opcode() != SpvOpSelectionMerge) {
      break;
    }

    // Check all of the successors of |bb| to see which lead to a use of
    // |inst| before reaching the merge node.
    bool used_in_multiple_blocks = false;
    uint32_t bb_used_in = 0;
    bb->ForEachSuccessorLabel([this, bb, &bb_used_in, &used_in_multiple_blocks,
                               &bbs_with_uses](uint32_t* succ_bb_id) {
      if (IntersectsPath(*succ_bb_id, bb->MergeBlockIdIfAny(), bbs_with_uses)) {
        if (bb_used_in == 0) {
          bb_used_in = *succ_bb_id;
        } else {
          used_in_multiple_blocks = true;
        }
      }
    });

    // If more than one successor (other than the merge block) uses |inst|,
    // then |inst| must stay in |bb|: none of the successors dominates all of
    // the uses of |inst|.
    if (used_in_multiple_blocks) {
      break;
    }

    if (bb_used_in == 0) {
      // If |inst| is not used before reaching the merge node, then we can
      // move |inst| to the merge node.
      bb = context()->get_instr_block(bb->MergeBlockIdIfAny());
    } else {
      // If the only successor that leads to a use of |inst| has more than 1
      // predecessor, then moving |inst| could cause it to be executed more
      // often, so we cannot move it.
      if (cfg()->preds(bb_used_in).size() != 1) {
        break;
      }

      // If |inst| is used at or after the merge block, then |bb_used_in| does
      // not dominate all of the uses, so we cannot move |inst| any further.
      if (IntersectsPath(bb->MergeBlockIdIfAny(), original_bb->id(),
                         bbs_with_uses)) {
        break;
      }

      // Otherwise, |bb_used_in| dominates all uses, so move |inst| into that
      // block.
      bb = context()->get_instr_block(bb_used_in);
    }
    continue;
  }
  return (bb != original_bb ? bb : nullptr);
}
  149. bool CodeSinkingPass::ReferencesMutableMemory(Instruction* inst) {
  150. if (!inst->IsLoad()) {
  151. return false;
  152. }
  153. Instruction* base_ptr = inst->GetBaseAddress();
  154. if (base_ptr->opcode() != SpvOpVariable) {
  155. return true;
  156. }
  157. if (base_ptr->IsReadOnlyPointer()) {
  158. return false;
  159. }
  160. if (HasUniformMemorySync()) {
  161. return true;
  162. }
  163. if (base_ptr->GetSingleWordInOperand(0) != SpvStorageClassUniform) {
  164. return true;
  165. }
  166. return HasPossibleStore(base_ptr);
  167. }
  168. bool CodeSinkingPass::HasUniformMemorySync() {
  169. if (checked_for_uniform_sync_) {
  170. return has_uniform_sync_;
  171. }
  172. bool has_sync = false;
  173. get_module()->ForEachInst([this, &has_sync](Instruction* inst) {
  174. switch (inst->opcode()) {
  175. case SpvOpMemoryBarrier: {
  176. uint32_t mem_semantics_id = inst->GetSingleWordInOperand(1);
  177. if (IsSyncOnUniform(mem_semantics_id)) {
  178. has_sync = true;
  179. }
  180. break;
  181. }
  182. case SpvOpControlBarrier:
  183. case SpvOpAtomicLoad:
  184. case SpvOpAtomicStore:
  185. case SpvOpAtomicExchange:
  186. case SpvOpAtomicIIncrement:
  187. case SpvOpAtomicIDecrement:
  188. case SpvOpAtomicIAdd:
  189. case SpvOpAtomicFAddEXT:
  190. case SpvOpAtomicISub:
  191. case SpvOpAtomicSMin:
  192. case SpvOpAtomicUMin:
  193. case SpvOpAtomicFMinEXT:
  194. case SpvOpAtomicSMax:
  195. case SpvOpAtomicUMax:
  196. case SpvOpAtomicFMaxEXT:
  197. case SpvOpAtomicAnd:
  198. case SpvOpAtomicOr:
  199. case SpvOpAtomicXor:
  200. case SpvOpAtomicFlagTestAndSet:
  201. case SpvOpAtomicFlagClear: {
  202. uint32_t mem_semantics_id = inst->GetSingleWordInOperand(2);
  203. if (IsSyncOnUniform(mem_semantics_id)) {
  204. has_sync = true;
  205. }
  206. break;
  207. }
  208. case SpvOpAtomicCompareExchange:
  209. case SpvOpAtomicCompareExchangeWeak:
  210. if (IsSyncOnUniform(inst->GetSingleWordInOperand(2)) ||
  211. IsSyncOnUniform(inst->GetSingleWordInOperand(3))) {
  212. has_sync = true;
  213. }
  214. break;
  215. default:
  216. break;
  217. }
  218. });
  219. has_uniform_sync_ = has_sync;
  220. return has_sync;
  221. }
  222. bool CodeSinkingPass::IsSyncOnUniform(uint32_t mem_semantics_id) const {
  223. const analysis::Constant* mem_semantics_const =
  224. context()->get_constant_mgr()->FindDeclaredConstant(mem_semantics_id);
  225. assert(mem_semantics_const != nullptr &&
  226. "Expecting memory semantics id to be a constant.");
  227. assert(mem_semantics_const->AsIntConstant() &&
  228. "Memory semantics should be an integer.");
  229. uint32_t mem_semantics_int = mem_semantics_const->GetU32();
  230. // If it does not affect uniform memory, then it is does not apply to uniform
  231. // memory.
  232. if ((mem_semantics_int & SpvMemorySemanticsUniformMemoryMask) == 0) {
  233. return false;
  234. }
  235. // Check if there is an acquire or release. If so not, this it does not add
  236. // any memory constraints.
  237. return (mem_semantics_int & (SpvMemorySemanticsAcquireMask |
  238. SpvMemorySemanticsAcquireReleaseMask |
  239. SpvMemorySemanticsReleaseMask)) != 0;
  240. }
// Returns whether a store through |var_inst| may exist.  |var_inst| must be a
// variable, or an access chain derived from one.
//
// NOTE(review): |WhileEachUser| stops at the first callback that returns
// false and yields true only when *every* callback returned true.  As
// written, this reports "all users are stores or chains leading to stores"
// rather than "some user is a store": a single non-store user (e.g. an
// OpLoad) makes the function return false even if a store also exists.
// Confirm against the def-use manager's WhileEachUser contract whether this
// is the intended semantics.
bool CodeSinkingPass::HasPossibleStore(Instruction* var_inst) {
  assert(var_inst->opcode() == SpvOpVariable ||
         var_inst->opcode() == SpvOpAccessChain ||
         var_inst->opcode() == SpvOpPtrAccessChain);
  return get_def_use_mgr()->WhileEachUser(var_inst, [this](Instruction* use) {
    switch (use->opcode()) {
      case SpvOpStore:
        // A direct store through the pointer.
        return true;
      case SpvOpAccessChain:
      case SpvOpPtrAccessChain:
        // Follow derived pointers recursively.
        return HasPossibleStore(use);
      default:
        // Any other use cannot write through the pointer.
        return false;
    }
  });
}
  257. bool CodeSinkingPass::IntersectsPath(uint32_t start, uint32_t end,
  258. const std::unordered_set<uint32_t>& set) {
  259. std::vector<uint32_t> worklist;
  260. worklist.push_back(start);
  261. std::unordered_set<uint32_t> already_done;
  262. already_done.insert(start);
  263. while (!worklist.empty()) {
  264. BasicBlock* bb = context()->get_instr_block(worklist.back());
  265. worklist.pop_back();
  266. if (bb->id() == end) {
  267. continue;
  268. }
  269. if (set.count(bb->id())) {
  270. return true;
  271. }
  272. bb->ForEachSuccessorLabel([&already_done, &worklist](uint32_t* succ_bb_id) {
  273. if (already_done.insert(*succ_bb_id).second) {
  274. worklist.push_back(*succ_bb_id);
  275. }
  276. });
  277. }
  278. return false;
  279. }
}  // namespace opt
  282. } // namespace spvtools