mem_pass.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. // Copyright (c) 2017 The Khronos Group Inc.
  2. // Copyright (c) 2017 Valve Corporation
  3. // Copyright (c) 2017 LunarG Inc.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "source/opt/mem_pass.h"
  17. #include <memory>
  18. #include <set>
  19. #include <vector>
  20. #include "source/cfa.h"
  21. #include "source/opt/basic_block.h"
  22. #include "source/opt/dominator_analysis.h"
  23. #include "source/opt/ir_context.h"
  24. #include "source/opt/iterator.h"
  25. namespace spvtools {
  26. namespace opt {
  namespace {

  // In-operand index of the source value of an OpCopyObject.
  constexpr uint32_t kCopyObjectOperandInIdx = 0;

  // In-operand indices of an OpTypePointer: the storage class comes first,
  // followed by the id of the pointee type.
  constexpr uint32_t kTypePointerStorageClassInIdx = 0;
  constexpr uint32_t kTypePointerTypeIdInIdx = 1;

  }  // namespace
  32. bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
  33. switch (typeInst->opcode()) {
  34. case SpvOpTypeInt:
  35. case SpvOpTypeFloat:
  36. case SpvOpTypeBool:
  37. case SpvOpTypeVector:
  38. case SpvOpTypeMatrix:
  39. case SpvOpTypeImage:
  40. case SpvOpTypeSampler:
  41. case SpvOpTypeSampledImage:
  42. case SpvOpTypePointer:
  43. return true;
  44. default:
  45. break;
  46. }
  47. return false;
  48. }
  49. bool MemPass::IsTargetType(const Instruction* typeInst) const {
  50. if (IsBaseTargetType(typeInst)) return true;
  51. if (typeInst->opcode() == SpvOpTypeArray) {
  52. if (!IsTargetType(
  53. get_def_use_mgr()->GetDef(typeInst->GetSingleWordOperand(1)))) {
  54. return false;
  55. }
  56. return true;
  57. }
  58. if (typeInst->opcode() != SpvOpTypeStruct) return false;
  59. // All struct members must be math type
  60. return typeInst->WhileEachInId([this](const uint32_t* tid) {
  61. Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid);
  62. if (!IsTargetType(compTypeInst)) return false;
  63. return true;
  64. });
  65. }
  66. bool MemPass::IsNonPtrAccessChain(const SpvOp opcode) const {
  67. return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain;
  68. }
  69. bool MemPass::IsPtr(uint32_t ptrId) {
  70. uint32_t varId = ptrId;
  71. Instruction* ptrInst = get_def_use_mgr()->GetDef(varId);
  72. while (ptrInst->opcode() == SpvOpCopyObject) {
  73. varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
  74. ptrInst = get_def_use_mgr()->GetDef(varId);
  75. }
  76. const SpvOp op = ptrInst->opcode();
  77. if (op == SpvOpVariable || IsNonPtrAccessChain(op)) return true;
  78. if (op != SpvOpFunctionParameter) return false;
  79. const uint32_t varTypeId = ptrInst->type_id();
  80. const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
  81. return varTypeInst->opcode() == SpvOpTypePointer;
  82. }
  83. Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
  84. *varId = ptrId;
  85. Instruction* ptrInst = get_def_use_mgr()->GetDef(*varId);
  86. Instruction* varInst;
  87. if (ptrInst->opcode() != SpvOpVariable &&
  88. ptrInst->opcode() != SpvOpFunctionParameter) {
  89. varInst = ptrInst->GetBaseAddress();
  90. } else {
  91. varInst = ptrInst;
  92. }
  93. if (varInst->opcode() == SpvOpVariable) {
  94. *varId = varInst->result_id();
  95. } else {
  96. *varId = 0;
  97. }
  98. while (ptrInst->opcode() == SpvOpCopyObject) {
  99. uint32_t temp = ptrInst->GetSingleWordInOperand(0);
  100. ptrInst = get_def_use_mgr()->GetDef(temp);
  101. }
  102. return ptrInst;
  103. }
  104. Instruction* MemPass::GetPtr(Instruction* ip, uint32_t* varId) {
  105. assert(ip->opcode() == SpvOpStore || ip->opcode() == SpvOpLoad ||
  106. ip->opcode() == SpvOpImageTexelPointer || ip->IsAtomicWithLoad());
  107. // All of these opcode place the pointer in position 0.
  108. const uint32_t ptrId = ip->GetSingleWordInOperand(0);
  109. return GetPtr(ptrId, varId);
  110. }
  111. bool MemPass::HasOnlyNamesAndDecorates(uint32_t id) const {
  112. return get_def_use_mgr()->WhileEachUser(id, [this](Instruction* user) {
  113. SpvOp op = user->opcode();
  114. if (op != SpvOpName && !IsNonTypeDecorate(op)) {
  115. return false;
  116. }
  117. return true;
  118. });
  119. }
  120. void MemPass::KillAllInsts(BasicBlock* bp, bool killLabel) {
  121. bp->KillAllInsts(killLabel);
  122. }
  123. bool MemPass::HasLoads(uint32_t varId) const {
  124. return !get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {
  125. SpvOp op = user->opcode();
  126. // TODO(): The following is slightly conservative. Could be
  127. // better handling of non-store/name.
  128. if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
  129. if (HasLoads(user->result_id())) {
  130. return false;
  131. }
  132. } else if (op != SpvOpStore && op != SpvOpName && !IsNonTypeDecorate(op)) {
  133. return false;
  134. }
  135. return true;
  136. });
  137. }
  138. bool MemPass::IsLiveVar(uint32_t varId) const {
  139. const Instruction* varInst = get_def_use_mgr()->GetDef(varId);
  140. // assume live if not a variable eg. function parameter
  141. if (varInst->opcode() != SpvOpVariable) return true;
  142. // non-function scope vars are live
  143. const uint32_t varTypeId = varInst->type_id();
  144. const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
  145. if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
  146. SpvStorageClassFunction)
  147. return true;
  148. // test if variable is loaded from
  149. return HasLoads(varId);
  150. }
  151. void MemPass::AddStores(uint32_t ptr_id, std::queue<Instruction*>* insts) {
  152. get_def_use_mgr()->ForEachUser(ptr_id, [this, insts](Instruction* user) {
  153. SpvOp op = user->opcode();
  154. if (IsNonPtrAccessChain(op)) {
  155. AddStores(user->result_id(), insts);
  156. } else if (op == SpvOpStore) {
  157. insts->push(user);
  158. }
  159. });
  160. }
  161. void MemPass::DCEInst(Instruction* inst,
  162. const std::function<void(Instruction*)>& call_back) {
  163. std::queue<Instruction*> deadInsts;
  164. deadInsts.push(inst);
  165. while (!deadInsts.empty()) {
  166. Instruction* di = deadInsts.front();
  167. // Don't delete labels
  168. if (di->opcode() == SpvOpLabel) {
  169. deadInsts.pop();
  170. continue;
  171. }
  172. // Remember operands
  173. std::set<uint32_t> ids;
  174. di->ForEachInId([&ids](uint32_t* iid) { ids.insert(*iid); });
  175. uint32_t varId = 0;
  176. // Remember variable if dead load
  177. if (di->opcode() == SpvOpLoad) (void)GetPtr(di, &varId);
  178. if (call_back) {
  179. call_back(di);
  180. }
  181. context()->KillInst(di);
  182. // For all operands with no remaining uses, add their instruction
  183. // to the dead instruction queue.
  184. for (auto id : ids)
  185. if (HasOnlyNamesAndDecorates(id)) {
  186. Instruction* odi = get_def_use_mgr()->GetDef(id);
  187. if (context()->IsCombinatorInstruction(odi)) deadInsts.push(odi);
  188. }
  189. // if a load was deleted and it was the variable's
  190. // last load, add all its stores to dead queue
  191. if (varId != 0 && !IsLiveVar(varId)) AddStores(varId, &deadInsts);
  192. deadInsts.pop();
  193. }
  194. }
  195. MemPass::MemPass() {}
  196. bool MemPass::HasOnlySupportedRefs(uint32_t varId) {
  197. return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {
  198. SpvOp op = user->opcode();
  199. if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
  200. !IsNonTypeDecorate(op)) {
  201. return false;
  202. }
  203. return true;
  204. });
  205. }
  206. uint32_t MemPass::Type2Undef(uint32_t type_id) {
  207. const auto uitr = type2undefs_.find(type_id);
  208. if (uitr != type2undefs_.end()) return uitr->second;
  209. const uint32_t undefId = TakeNextId();
  210. if (undefId == 0) {
  211. return 0;
  212. }
  213. std::unique_ptr<Instruction> undef_inst(
  214. new Instruction(context(), SpvOpUndef, type_id, undefId, {}));
  215. get_def_use_mgr()->AnalyzeInstDefUse(&*undef_inst);
  216. get_module()->AddGlobalValue(std::move(undef_inst));
  217. type2undefs_[type_id] = undefId;
  218. return undefId;
  219. }
  220. bool MemPass::IsTargetVar(uint32_t varId) {
  221. if (varId == 0) {
  222. return false;
  223. }
  224. if (seen_non_target_vars_.find(varId) != seen_non_target_vars_.end())
  225. return false;
  226. if (seen_target_vars_.find(varId) != seen_target_vars_.end()) return true;
  227. const Instruction* varInst = get_def_use_mgr()->GetDef(varId);
  228. if (varInst->opcode() != SpvOpVariable) return false;
  229. const uint32_t varTypeId = varInst->type_id();
  230. const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
  231. if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
  232. SpvStorageClassFunction) {
  233. seen_non_target_vars_.insert(varId);
  234. return false;
  235. }
  236. const uint32_t varPteTypeId =
  237. varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
  238. Instruction* varPteTypeInst = get_def_use_mgr()->GetDef(varPteTypeId);
  239. if (!IsTargetType(varPteTypeInst)) {
  240. seen_non_target_vars_.insert(varId);
  241. return false;
  242. }
  243. seen_target_vars_.insert(varId);
  244. return true;
  245. }
  246. // Remove all |phi| operands coming from unreachable blocks (i.e., blocks not in
  247. // |reachable_blocks|). There are two types of removal that this function can
  248. // perform:
  249. //
  250. // 1- Any operand that comes directly from an unreachable block is completely
  251. // removed. Since the block is unreachable, the edge between the unreachable
  252. // block and the block holding |phi| has been removed.
  253. //
  254. // 2- Any operand that comes via a live block and was defined at an unreachable
  255. // block gets its value replaced with an OpUndef value. Since the argument
  256. // was generated in an unreachable block, it no longer exists, so it cannot
  257. // be referenced. However, since the value does not reach |phi| directly
  258. // from the unreachable block, the operand cannot be removed from |phi|.
  259. // Therefore, we replace the argument value with OpUndef.
  260. //
  261. // For example, in the switch() below, assume that we want to remove the
  262. // argument with value %11 coming from block %41.
  263. //
  264. // [ ... ]
  265. // %41 = OpLabel <--- Unreachable block
  266. // %11 = OpLoad %int %y
  267. // [ ... ]
  268. // OpSelectionMerge %16 None
  269. // OpSwitch %12 %16 10 %13 13 %14 18 %15
  270. // %13 = OpLabel
  271. // OpBranch %16
  272. // %14 = OpLabel
  273. // OpStore %outparm %int_14
  274. // OpBranch %16
  275. // %15 = OpLabel
  276. // OpStore %outparm %int_15
  277. // OpBranch %16
  278. // %16 = OpLabel
  279. // %30 = OpPhi %int %11 %41 %int_42 %13 %11 %14 %11 %15
  280. //
  281. // Since %41 is now an unreachable block, the first operand of |phi| needs to
  282. // be removed completely. But the operands (%11 %14) and (%11 %15) cannot be
  283. // removed because %14 and %15 are reachable blocks. Since %11 no longer exist,
  284. // in those arguments, we replace all references to %11 with an OpUndef value.
  285. // This results in |phi| looking like:
  286. //
  287. // %50 = OpUndef %int
  288. // [ ... ]
  289. // %30 = OpPhi %int %int_42 %13 %50 %14 %50 %15
  290. void MemPass::RemovePhiOperands(
  291. Instruction* phi, const std::unordered_set<BasicBlock*>& reachable_blocks) {
  292. std::vector<Operand> keep_operands;
  293. uint32_t type_id = 0;
  294. // The id of an undefined value we've generated.
  295. uint32_t undef_id = 0;
  296. // Traverse all the operands in |phi|. Build the new operand vector by adding
  297. // all the original operands from |phi| except the unwanted ones.
  298. for (uint32_t i = 0; i < phi->NumOperands();) {
  299. if (i < 2) {
  300. // The first two arguments are always preserved.
  301. keep_operands.push_back(phi->GetOperand(i));
  302. ++i;
  303. continue;
  304. }
  305. // The remaining Phi arguments come in pairs. Index 'i' contains the
  306. // variable id, index 'i + 1' is the originating block id.
  307. assert(i % 2 == 0 && i < phi->NumOperands() - 1 &&
  308. "malformed Phi arguments");
  309. BasicBlock* in_block = cfg()->block(phi->GetSingleWordOperand(i + 1));
  310. if (reachable_blocks.find(in_block) == reachable_blocks.end()) {
  311. // If the incoming block is unreachable, remove both operands as this
  312. // means that the |phi| has lost an incoming edge.
  313. i += 2;
  314. continue;
  315. }
  316. // In all other cases, the operand must be kept but may need to be changed.
  317. uint32_t arg_id = phi->GetSingleWordOperand(i);
  318. Instruction* arg_def_instr = get_def_use_mgr()->GetDef(arg_id);
  319. BasicBlock* def_block = context()->get_instr_block(arg_def_instr);
  320. if (def_block &&
  321. reachable_blocks.find(def_block) == reachable_blocks.end()) {
  322. // If the current |phi| argument was defined in an unreachable block, it
  323. // means that this |phi| argument is no longer defined. Replace it with
  324. // |undef_id|.
  325. if (!undef_id) {
  326. type_id = arg_def_instr->type_id();
  327. undef_id = Type2Undef(type_id);
  328. }
  329. keep_operands.push_back(
  330. Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {undef_id}));
  331. } else {
  332. // Otherwise, the argument comes from a reachable block or from no block
  333. // at all (meaning that it was defined in the global section of the
  334. // program). In both cases, keep the argument intact.
  335. keep_operands.push_back(phi->GetOperand(i));
  336. }
  337. keep_operands.push_back(phi->GetOperand(i + 1));
  338. i += 2;
  339. }
  340. context()->ForgetUses(phi);
  341. phi->ReplaceOperands(keep_operands);
  342. context()->AnalyzeUses(phi);
  343. }
  344. void MemPass::RemoveBlock(Function::iterator* bi) {
  345. auto& rm_block = **bi;
  346. // Remove instructions from the block.
  347. rm_block.ForEachInst([&rm_block, this](Instruction* inst) {
  348. // Note that we do not kill the block label instruction here. The label
  349. // instruction is needed to identify the block, which is needed by the
  350. // removal of phi operands.
  351. if (inst != rm_block.GetLabelInst()) {
  352. context()->KillInst(inst);
  353. }
  354. });
  355. // Remove the label instruction last.
  356. auto label = rm_block.GetLabelInst();
  357. context()->KillInst(label);
  358. *bi = bi->Erase();
  359. }
  360. bool MemPass::RemoveUnreachableBlocks(Function* func) {
  361. bool modified = false;
  362. // Mark reachable all blocks reachable from the function's entry block.
  363. std::unordered_set<BasicBlock*> reachable_blocks;
  364. std::unordered_set<BasicBlock*> visited_blocks;
  365. std::queue<BasicBlock*> worklist;
  366. reachable_blocks.insert(func->entry().get());
  367. // Initially mark the function entry point as reachable.
  368. worklist.push(func->entry().get());
  369. auto mark_reachable = [&reachable_blocks, &visited_blocks, &worklist,
  370. this](uint32_t label_id) {
  371. auto successor = cfg()->block(label_id);
  372. if (visited_blocks.count(successor) == 0) {
  373. reachable_blocks.insert(successor);
  374. worklist.push(successor);
  375. visited_blocks.insert(successor);
  376. }
  377. };
  378. // Transitively mark all blocks reachable from the entry as reachable.
  379. while (!worklist.empty()) {
  380. BasicBlock* block = worklist.front();
  381. worklist.pop();
  382. // All the successors of a live block are also live.
  383. static_cast<const BasicBlock*>(block)->ForEachSuccessorLabel(
  384. mark_reachable);
  385. // All the Merge and ContinueTarget blocks of a live block are also live.
  386. block->ForMergeAndContinueLabel(mark_reachable);
  387. }
  388. // Update operands of Phi nodes that reference unreachable blocks.
  389. for (auto& block : *func) {
  390. // If the block is about to be removed, don't bother updating its
  391. // Phi instructions.
  392. if (reachable_blocks.count(&block) == 0) {
  393. continue;
  394. }
  395. // If the block is reachable and has Phi instructions, remove all
  396. // operands from its Phi instructions that reference unreachable blocks.
  397. // If the block has no Phi instructions, this is a no-op.
  398. block.ForEachPhiInst([&reachable_blocks, this](Instruction* phi) {
  399. RemovePhiOperands(phi, reachable_blocks);
  400. });
  401. }
  402. // Erase unreachable blocks.
  403. for (auto ebi = func->begin(); ebi != func->end();) {
  404. if (reachable_blocks.count(&*ebi) == 0) {
  405. RemoveBlock(&ebi);
  406. modified = true;
  407. } else {
  408. ++ebi;
  409. }
  410. }
  411. return modified;
  412. }
  413. bool MemPass::CFGCleanup(Function* func) {
  414. bool modified = false;
  415. modified |= RemoveUnreachableBlocks(func);
  416. return modified;
  417. }
  418. void MemPass::CollectTargetVars(Function* func) {
  419. seen_target_vars_.clear();
  420. seen_non_target_vars_.clear();
  421. type2undefs_.clear();
  422. // Collect target (and non-) variable sets. Remove variables with
  423. // non-load/store refs from target variable set
  424. for (auto& blk : *func) {
  425. for (auto& inst : blk) {
  426. switch (inst.opcode()) {
  427. case SpvOpStore:
  428. case SpvOpLoad: {
  429. uint32_t varId;
  430. (void)GetPtr(&inst, &varId);
  431. if (!IsTargetVar(varId)) break;
  432. if (HasOnlySupportedRefs(varId)) break;
  433. seen_non_target_vars_.insert(varId);
  434. seen_target_vars_.erase(varId);
  435. } break;
  436. default:
  437. break;
  438. }
  439. }
  440. }
  441. }
  442. } // namespace opt
  443. } // namespace spvtools