// code_sink.cpp
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "code_sink.h"
  15. #include <vector>
  16. #include "source/opt/instruction.h"
  17. #include "source/opt/ir_context.h"
  18. #include "source/util/bit_vector.h"
  19. namespace spvtools {
  20. namespace opt {
  21. Pass::Status CodeSinkingPass::Process() {
  22. bool modified = false;
  23. for (Function& function : *get_module()) {
  24. cfg()->ForEachBlockInPostOrder(function.entry().get(),
  25. [&modified, this](BasicBlock* bb) {
  26. if (SinkInstructionsInBB(bb)) {
  27. modified = true;
  28. }
  29. });
  30. }
  31. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  32. }
// Attempts to sink each instruction in |bb|.  Returns true if any
// instruction was moved.  The block is scanned in reverse so that an
// instruction's uses inside |bb| are seen before the instruction itself.
bool CodeSinkingPass::SinkInstructionsInBB(BasicBlock* bb) {
  bool modified = false;
  for (auto inst = bb->rbegin(); inst != bb->rend(); ++inst) {
    if (SinkInstruction(&*inst)) {
      // |inst| was moved to another block, invalidating the iterator, so
      // restart the scan from the end of |bb|.  The loop's ++ then steps
      // past rbegin() (the terminator), which is harmless: terminators are
      // never OpLoad/OpAccessChain, the only opcodes SinkInstruction moves.
      inst = bb->rbegin();
      modified = true;
    }
  }
  return modified;
}
  43. bool CodeSinkingPass::SinkInstruction(Instruction* inst) {
  44. if (inst->opcode() != spv::Op::OpLoad &&
  45. inst->opcode() != spv::Op::OpAccessChain) {
  46. return false;
  47. }
  48. if (ReferencesMutableMemory(inst)) {
  49. return false;
  50. }
  51. if (BasicBlock* target_bb = FindNewBasicBlockFor(inst)) {
  52. Instruction* pos = &*target_bb->begin();
  53. while (pos->opcode() == spv::Op::OpPhi) {
  54. pos = pos->NextNode();
  55. }
  56. inst->InsertBefore(pos);
  57. context()->set_instr_block(inst, target_bb);
  58. return true;
  59. }
  60. return false;
  61. }
// Returns the basic block that |inst| should be moved to: a block reached
// later in the CFG that still dominates all uses of |inst| and cannot cause
// |inst| to execute more often than it does now.  Returns nullptr if no
// block better than the current one is found.
BasicBlock* CodeSinkingPass::FindNewBasicBlockFor(Instruction* inst) {
  assert(inst->result_id() != 0 && "Instruction should have a result.");
  BasicBlock* original_bb = context()->get_instr_block(inst);
  BasicBlock* bb = original_bb;

  // Collect the ids of all blocks containing a use of |inst|.  For a use in
  // an OpPhi, the relevant block is the incoming-edge block (the operand
  // that follows the value at |idx|), not the block holding the phi itself.
  std::unordered_set<uint32_t> bbs_with_uses;
  get_def_use_mgr()->ForEachUse(
      inst, [&bbs_with_uses, this](Instruction* use, uint32_t idx) {
        if (use->opcode() != spv::Op::OpPhi) {
          BasicBlock* use_bb = context()->get_instr_block(use);
          if (use_bb) {
            bbs_with_uses.insert(use_bb->id());
          }
        } else {
          bbs_with_uses.insert(use->GetSingleWordOperand(idx + 1));
        }
      });

  while (true) {
    // If |inst| is used in |bb|, then |inst| cannot be moved any further.
    if (bbs_with_uses.count(bb->id())) {
      break;
    }

    // If |bb| has one successor (succ_bb), and |bb| is the only predecessor
    // of succ_bb, then |inst| can be moved to succ_bb.  If succ_bb has more
    // than one predecessor, then moving |inst| into succ_bb could cause it
    // to be executed more often, so the search has to stop.
    if (bb->terminator()->opcode() == spv::Op::OpBranch) {
      uint32_t succ_bb_id = bb->terminator()->GetSingleWordInOperand(0);
      if (cfg()->preds(succ_bb_id).size() == 1) {
        bb = context()->get_instr_block(succ_bb_id);
        continue;
      } else {
        break;
      }
    }

    // The remaining checks need to know the merge node.  If there is no
    // merge instruction, or it is an OpLoopMerge, then the branch is a break
    // or continue.  We could figure it out, but it is not worth doing now.
    Instruction* merge_inst = bb->GetMergeInst();
    if (merge_inst == nullptr ||
        merge_inst->opcode() != spv::Op::OpSelectionMerge) {
      break;
    }

    // Check all of the successors of |bb| to see which lead to a use of
    // |inst| before reaching the merge node.
    bool used_in_multiple_blocks = false;
    uint32_t bb_used_in = 0;
    bb->ForEachSuccessorLabel([this, bb, &bb_used_in, &used_in_multiple_blocks,
                               &bbs_with_uses](uint32_t* succ_bb_id) {
      if (IntersectsPath(*succ_bb_id, bb->MergeBlockIdIfAny(),
                         bbs_with_uses)) {
        if (bb_used_in == 0) {
          bb_used_in = *succ_bb_id;
        } else {
          used_in_multiple_blocks = true;
        }
      }
    });

    // If more than one successor (other than the merge block) leads to a use
    // of |inst|, then |inst| has to stay in |bb|: none of the successors
    // dominates all uses of |inst|.
    if (used_in_multiple_blocks) {
      break;
    }

    if (bb_used_in == 0) {
      // If |inst| is not used before reaching the merge node, then we can
      // move |inst| to the merge node.
      bb = context()->get_instr_block(bb->MergeBlockIdIfAny());
    } else {
      // If the only successor that leads to a use of |inst| has more than 1
      // predecessor, then moving |inst| could cause it to be executed more
      // often, so we cannot move it.
      if (cfg()->preds(bb_used_in).size() != 1) {
        break;
      }

      // If |inst| is used at or after the merge block, then |bb_used_in|
      // does not dominate all of the uses.  So we cannot move |inst| any
      // further.
      if (IntersectsPath(bb->MergeBlockIdIfAny(), original_bb->id(),
                         bbs_with_uses)) {
        break;
      }

      // Otherwise, |bb_used_in| dominates all uses, so move |inst| into that
      // block.
      bb = context()->get_instr_block(bb_used_in);
    }
    continue;
  }
  return (bb != original_bb ? bb : nullptr);
}
  149. bool CodeSinkingPass::ReferencesMutableMemory(Instruction* inst) {
  150. if (!inst->IsLoad()) {
  151. return false;
  152. }
  153. Instruction* base_ptr = inst->GetBaseAddress();
  154. if (base_ptr->opcode() != spv::Op::OpVariable) {
  155. return true;
  156. }
  157. if (base_ptr->IsReadOnlyPointer()) {
  158. return false;
  159. }
  160. if (HasUniformMemorySync()) {
  161. return true;
  162. }
  163. if (spv::StorageClass(base_ptr->GetSingleWordInOperand(0)) !=
  164. spv::StorageClass::Uniform) {
  165. return true;
  166. }
  167. return HasPossibleStore(base_ptr);
  168. }
  169. bool CodeSinkingPass::HasUniformMemorySync() {
  170. if (checked_for_uniform_sync_) {
  171. return has_uniform_sync_;
  172. }
  173. bool has_sync = false;
  174. get_module()->ForEachInst([this, &has_sync](Instruction* inst) {
  175. switch (inst->opcode()) {
  176. case spv::Op::OpMemoryBarrier: {
  177. uint32_t mem_semantics_id = inst->GetSingleWordInOperand(1);
  178. if (IsSyncOnUniform(mem_semantics_id)) {
  179. has_sync = true;
  180. }
  181. break;
  182. }
  183. case spv::Op::OpControlBarrier:
  184. case spv::Op::OpAtomicLoad:
  185. case spv::Op::OpAtomicStore:
  186. case spv::Op::OpAtomicExchange:
  187. case spv::Op::OpAtomicIIncrement:
  188. case spv::Op::OpAtomicIDecrement:
  189. case spv::Op::OpAtomicIAdd:
  190. case spv::Op::OpAtomicFAddEXT:
  191. case spv::Op::OpAtomicISub:
  192. case spv::Op::OpAtomicSMin:
  193. case spv::Op::OpAtomicUMin:
  194. case spv::Op::OpAtomicFMinEXT:
  195. case spv::Op::OpAtomicSMax:
  196. case spv::Op::OpAtomicUMax:
  197. case spv::Op::OpAtomicFMaxEXT:
  198. case spv::Op::OpAtomicAnd:
  199. case spv::Op::OpAtomicOr:
  200. case spv::Op::OpAtomicXor:
  201. case spv::Op::OpAtomicFlagTestAndSet:
  202. case spv::Op::OpAtomicFlagClear: {
  203. uint32_t mem_semantics_id = inst->GetSingleWordInOperand(2);
  204. if (IsSyncOnUniform(mem_semantics_id)) {
  205. has_sync = true;
  206. }
  207. break;
  208. }
  209. case spv::Op::OpAtomicCompareExchange:
  210. case spv::Op::OpAtomicCompareExchangeWeak:
  211. if (IsSyncOnUniform(inst->GetSingleWordInOperand(2)) ||
  212. IsSyncOnUniform(inst->GetSingleWordInOperand(3))) {
  213. has_sync = true;
  214. }
  215. break;
  216. default:
  217. break;
  218. }
  219. });
  220. has_uniform_sync_ = has_sync;
  221. return has_sync;
  222. }
  223. bool CodeSinkingPass::IsSyncOnUniform(uint32_t mem_semantics_id) const {
  224. const analysis::Constant* mem_semantics_const =
  225. context()->get_constant_mgr()->FindDeclaredConstant(mem_semantics_id);
  226. assert(mem_semantics_const != nullptr &&
  227. "Expecting memory semantics id to be a constant.");
  228. assert(mem_semantics_const->AsIntConstant() &&
  229. "Memory semantics should be an integer.");
  230. uint32_t mem_semantics_int = mem_semantics_const->GetU32();
  231. // If it does not affect uniform memory, then it is does not apply to uniform
  232. // memory.
  233. if ((mem_semantics_int & uint32_t(spv::MemorySemanticsMask::UniformMemory)) ==
  234. 0) {
  235. return false;
  236. }
  237. // Check if there is an acquire or release. If so not, this it does not add
  238. // any memory constraints.
  239. return (mem_semantics_int &
  240. uint32_t(spv::MemorySemanticsMask::Acquire |
  241. spv::MemorySemanticsMask::AcquireRelease |
  242. spv::MemorySemanticsMask::Release)) != 0;
  243. }
  244. bool CodeSinkingPass::HasPossibleStore(Instruction* var_inst) {
  245. assert(var_inst->opcode() == spv::Op::OpVariable ||
  246. var_inst->opcode() == spv::Op::OpAccessChain ||
  247. var_inst->opcode() == spv::Op::OpPtrAccessChain);
  248. return get_def_use_mgr()->WhileEachUser(var_inst, [this](Instruction* use) {
  249. switch (use->opcode()) {
  250. case spv::Op::OpStore:
  251. return true;
  252. case spv::Op::OpAccessChain:
  253. case spv::Op::OpPtrAccessChain:
  254. return HasPossibleStore(use);
  255. default:
  256. return false;
  257. }
  258. });
  259. }
  260. bool CodeSinkingPass::IntersectsPath(uint32_t start, uint32_t end,
  261. const std::unordered_set<uint32_t>& set) {
  262. std::vector<uint32_t> worklist;
  263. worklist.push_back(start);
  264. std::unordered_set<uint32_t> already_done;
  265. already_done.insert(start);
  266. while (!worklist.empty()) {
  267. BasicBlock* bb = context()->get_instr_block(worklist.back());
  268. worklist.pop_back();
  269. if (bb->id() == end) {
  270. continue;
  271. }
  272. if (set.count(bb->id())) {
  273. return true;
  274. }
  275. bb->ForEachSuccessorLabel([&already_done, &worklist](uint32_t* succ_bb_id) {
  276. if (already_done.insert(*succ_bb_id).second) {
  277. worklist.push_back(*succ_bb_id);
  278. }
  279. });
  280. }
  281. return false;
  282. }
}  // namespace opt
  285. } // namespace spvtools