code_sink.cpp

// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "code_sink.h"

#include <set>
#include <unordered_set>
#include <vector>

#include "source/opt/instruction.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
#include "source/util/bit_vector.h"

namespace spvtools {
namespace opt {
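
// Sinks eligible OpLoad and OpAccessChain instructions in every function of
// the module, visiting the blocks of each function in post order.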
Pass::Status CodeSinkingPass::Process() {
  bool modified = false;
  for (Function& function : *get_module()) {
    cfg()->ForEachBlockInPostOrder(function.entry().get(),
                                   [&modified, this](BasicBlock* bb) {
                                     if (SinkInstructionsInBB(bb)) {
                                       modified = true;
                                     }
                                   });
  }
  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
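
// Attempts to sink every instruction in |bb|.  Whenever an instruction is
// moved, the reverse iteration restarts from the end of the block so the
// remaining instructions are revisited.  Returns true if anything was moved.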
bool CodeSinkingPass::SinkInstructionsInBB(BasicBlock* bb) {
  bool modified = false;
  for (auto inst = bb->rbegin(); inst != bb->rend(); ++inst) {
    if (SinkInstruction(&*inst)) {
      inst = bb->rbegin();
      modified = true;
    }
  }
  return modified;
}
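
// Moves |inst| into a later basic block when a legal target exists.  Only
// OpLoad and OpAccessChain instructions that do not reference mutable memory
// are candidates.  The instruction is inserted after any OpPhi instructions at
// the top of the target block.  Returns true if |inst| was moved.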
bool CodeSinkingPass::SinkInstruction(Instruction* inst) {
  if (inst->opcode() != spv::Op::OpLoad &&
      inst->opcode() != spv::Op::OpAccessChain) {
    return false;
  }

  if (ReferencesMutableMemory(inst)) {
    return false;
  }

  if (BasicBlock* target_bb = FindNewBasicBlockFor(inst)) {
    Instruction* pos = &*target_bb->begin();
    while (pos->opcode() == spv::Op::OpPhi) {
      pos = pos->NextNode();
    }

    inst->InsertBefore(pos);
    context()->set_instr_block(inst, target_bb);
    return true;
  }
  return false;
}
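
// Returns the basic block into which |inst| can legally be sunk, or nullptr if
// no block better than its current one is found.  The search follows
// single-predecessor successors and selection constructs, and stops as soon as
// moving further could cause |inst| to execute more often or to stop
// dominating one of its uses.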
BasicBlock* CodeSinkingPass::FindNewBasicBlockFor(Instruction* inst) {
  assert(inst->result_id() != 0 && "Instruction should have a result.");
  BasicBlock* original_bb = context()->get_instr_block(inst);
  BasicBlock* bb = original_bb;

  std::unordered_set<uint32_t> bbs_with_uses;
  get_def_use_mgr()->ForEachUse(
      inst, [&bbs_with_uses, this](Instruction* use, uint32_t idx) {
        if (use->opcode() != spv::Op::OpPhi) {
          BasicBlock* use_bb = context()->get_instr_block(use);
          if (use_bb) {
            bbs_with_uses.insert(use_bb->id());
          }
        } else {
          bbs_with_uses.insert(use->GetSingleWordOperand(idx + 1));
        }
      });

  while (true) {
    // If |inst| is used in |bb|, then |inst| cannot be moved any further.
    if (bbs_with_uses.count(bb->id())) {
      break;
    }
    // If |bb| has one successor (succ_bb), and |bb| is the only predecessor
    // of succ_bb, then |inst| can be moved to succ_bb.  If succ_bb has more
    // than one predecessor, then moving |inst| into succ_bb could cause it to
    // be executed more often, so the search has to stop.
    if (bb->terminator()->opcode() == spv::Op::OpBranch) {
      uint32_t succ_bb_id = bb->terminator()->GetSingleWordInOperand(0);
      if (cfg()->preds(succ_bb_id).size() == 1) {
        bb = context()->get_instr_block(succ_bb_id);
        continue;
      } else {
        break;
      }
    }

    // The remaining checks need to know the merge node.  If there is no merge
    // instruction, or it is an OpLoopMerge, then this is a break or continue.
    // We could figure it out, but it is not worth doing now.
    Instruction* merge_inst = bb->GetMergeInst();
    if (merge_inst == nullptr ||
        merge_inst->opcode() != spv::Op::OpSelectionMerge) {
      break;
    }

    // Check all of the successors of |bb| to see which lead to a use of |inst|
    // before reaching the merge node.
    bool used_in_multiple_blocks = false;
    uint32_t bb_used_in = 0;
    bb->ForEachSuccessorLabel([this, bb, &bb_used_in, &used_in_multiple_blocks,
                               &bbs_with_uses](uint32_t* succ_bb_id) {
      if (IntersectsPath(*succ_bb_id, bb->MergeBlockIdIfAny(), bbs_with_uses)) {
        if (bb_used_in == 0) {
          bb_used_in = *succ_bb_id;
        } else {
          used_in_multiple_blocks = true;
        }
      }
    });

    // If more than one successor that is not the merge block uses |inst|, then
    // |inst| has to stay in |bb| because none of the successors dominates all
    // of the uses of |inst|.
    if (used_in_multiple_blocks) {
      break;
    }

    if (bb_used_in == 0) {
      // If |inst| is not used before reaching the merge node, then we can move
      // |inst| to the merge node.
      bb = context()->get_instr_block(bb->MergeBlockIdIfAny());
    } else {
      // If the only successor that leads to a use of |inst| has more than one
      // predecessor, then moving |inst| could cause it to be executed more
      // often, so we cannot move it.
      if (cfg()->preds(bb_used_in).size() != 1) {
        break;
      }

      // If |inst| is used after the merge block, then |bb_used_in| does not
      // dominate all of the uses.  So we cannot move |inst| any further.
      if (IntersectsPath(bb->MergeBlockIdIfAny(), original_bb->id(),
                         bbs_with_uses)) {
        break;
      }

      // Otherwise, |bb_used_in| dominates all uses, so move |inst| into that
      // block.
      bb = context()->get_instr_block(bb_used_in);
    }
    continue;
  }
  return (bb != original_bb ? bb : nullptr);
}
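
// Returns true if |inst| is a load that might read memory that can be
// modified, which makes it unsafe to sink.  Loads through read-only pointers,
// and loads from Uniform-storage variables with no possible store and no
// uniform memory synchronization, are treated as immutable.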
bool CodeSinkingPass::ReferencesMutableMemory(Instruction* inst) {
  if (!inst->IsLoad()) {
    return false;
  }

  Instruction* base_ptr = inst->GetBaseAddress();
  if (base_ptr->opcode() != spv::Op::OpVariable) {
    return true;
  }

  if (base_ptr->IsReadOnlyPointer()) {
    return false;
  }

  if (HasUniformMemorySync()) {
    return true;
  }

  if (spv::StorageClass(base_ptr->GetSingleWordInOperand(0)) !=
      spv::StorageClass::Uniform) {
    return true;
  }

  return HasPossibleStore(base_ptr);
}
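
// Returns true if the module contains a memory barrier, control barrier, or
// atomic operation whose memory semantics apply acquire or release ordering to
// uniform memory.  The result of the module scan is cached.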
bool CodeSinkingPass::HasUniformMemorySync() {
  if (checked_for_uniform_sync_) {
    return has_uniform_sync_;
  }

  bool has_sync = false;
  get_module()->ForEachInst([this, &has_sync](Instruction* inst) {
    switch (inst->opcode()) {
      case spv::Op::OpMemoryBarrier: {
        uint32_t mem_semantics_id = inst->GetSingleWordInOperand(1);
        if (IsSyncOnUniform(mem_semantics_id)) {
          has_sync = true;
        }
        break;
      }
      case spv::Op::OpControlBarrier:
      case spv::Op::OpAtomicLoad:
      case spv::Op::OpAtomicStore:
      case spv::Op::OpAtomicExchange:
      case spv::Op::OpAtomicIIncrement:
      case spv::Op::OpAtomicIDecrement:
      case spv::Op::OpAtomicIAdd:
      case spv::Op::OpAtomicFAddEXT:
      case spv::Op::OpAtomicISub:
      case spv::Op::OpAtomicSMin:
      case spv::Op::OpAtomicUMin:
      case spv::Op::OpAtomicFMinEXT:
      case spv::Op::OpAtomicSMax:
      case spv::Op::OpAtomicUMax:
      case spv::Op::OpAtomicFMaxEXT:
      case spv::Op::OpAtomicAnd:
      case spv::Op::OpAtomicOr:
      case spv::Op::OpAtomicXor:
      case spv::Op::OpAtomicFlagTestAndSet:
      case spv::Op::OpAtomicFlagClear: {
        uint32_t mem_semantics_id = inst->GetSingleWordInOperand(2);
        if (IsSyncOnUniform(mem_semantics_id)) {
          has_sync = true;
        }
        break;
      }
      case spv::Op::OpAtomicCompareExchange:
      case spv::Op::OpAtomicCompareExchangeWeak:
        if (IsSyncOnUniform(inst->GetSingleWordInOperand(2)) ||
            IsSyncOnUniform(inst->GetSingleWordInOperand(3))) {
          has_sync = true;
        }
        break;
      default:
        break;
    }
  });
  // Record that the module has been scanned so the cached result is reused on
  // later queries.
  checked_for_uniform_sync_ = true;
  has_uniform_sync_ = has_sync;
  return has_sync;
}
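
// Returns true if the memory semantics value defined by the constant with id
// |mem_semantics_id| includes the UniformMemory flag together with an acquire
// or release ordering.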
bool CodeSinkingPass::IsSyncOnUniform(uint32_t mem_semantics_id) const {
  const analysis::Constant* mem_semantics_const =
      context()->get_constant_mgr()->FindDeclaredConstant(mem_semantics_id);
  assert(mem_semantics_const != nullptr &&
         "Expecting memory semantics id to be a constant.");
  assert(mem_semantics_const->AsIntConstant() &&
         "Memory semantics should be an integer.");
  uint32_t mem_semantics_int = mem_semantics_const->GetU32();

  // If the semantics do not include uniform memory, then they do not apply to
  // uniform memory.
  if ((mem_semantics_int &
       uint32_t(spv::MemorySemanticsMask::UniformMemory)) == 0) {
    return false;
  }

  // Check if there is an acquire or release.  If not, the operation does not
  // add any memory ordering constraints.
  return (mem_semantics_int &
          uint32_t(spv::MemorySemanticsMask::Acquire |
                   spv::MemorySemanticsMask::AcquireRelease |
                   spv::MemorySemanticsMask::Release)) != 0;
}
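
// Returns true if there may be a store to the memory reached through
// |var_inst|, either directly or through an access chain derived from it.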
bool CodeSinkingPass::HasPossibleStore(Instruction* var_inst) {
  assert(var_inst->opcode() == spv::Op::OpVariable ||
         var_inst->opcode() == spv::Op::OpAccessChain ||
         var_inst->opcode() == spv::Op::OpPtrAccessChain);

  return get_def_use_mgr()->WhileEachUser(var_inst, [this](Instruction* use) {
    switch (use->opcode()) {
      case spv::Op::OpStore:
        return true;
      case spv::Op::OpAccessChain:
      case spv::Op::OpPtrAccessChain:
        return HasPossibleStore(use);
      default:
        return false;
    }
  });
}
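
// Returns true if a block in |set| can be reached from |start| along a path in
// the CFG that does not pass through |end|.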
bool CodeSinkingPass::IntersectsPath(uint32_t start, uint32_t end,
                                     const std::unordered_set<uint32_t>& set) {
  std::vector<uint32_t> worklist;
  worklist.push_back(start);
  std::unordered_set<uint32_t> already_done;
  already_done.insert(start);

  while (!worklist.empty()) {
    BasicBlock* bb = context()->get_instr_block(worklist.back());
    worklist.pop_back();

    if (bb->id() == end) {
      continue;
    }

    if (set.count(bb->id())) {
      return true;
    }

    bb->ForEachSuccessorLabel([&already_done, &worklist](uint32_t* succ_bb_id) {
      if (already_done.insert(*succ_bb_id).second) {
        worklist.push_back(*succ_bb_id);
      }
    });
  }
  return false;
}

}  // namespace opt
}  // namespace spvtools